cloud.rs

use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
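
// For illustration only: given the serde derives above, a custom model with a
// thinking mode might be declared in the user's settings JSON roughly as
// follows (hypothetical values; the exact settings location is defined by the
// settings schema, not by this file):
//
//   "available_models": [{
//     "provider": "anthropic",
//     "name": "some-thinking-model",
//     "max_tokens": 200000,
//     "mode": { "type": "thinking", "budget_tokens": 4096 }
//   }]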

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

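                    // Wait until the client reports a connected status before
                    // fetching models, polling every 100ms.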
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
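            // Re-fetch the LLM API token in the background whenever the server
            // signals that it should be refreshed.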
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            Ok(serde_json::from_str(&body)?)
        } else {
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
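        // Mirror the client's connection status into the provider state; the
        // task ends once the state entity has been dropped.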
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

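        // Retry at most once: if the server reports an expired LLM token,
        // refresh the token and reissue the request.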
        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

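            // A FORBIDDEN response carrying the subscription limit header means
            // the plan's model request quota has been exhausted.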
            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
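        // No Retry-After hint is extracted from the response headers here.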
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
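                // Token counts for Google models are not computed locally; they
                // are fetched from the Zed LLM API's /count_tokens endpoint.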
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
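                // A "-thinking" id suffix marks the synthetic thinking variants
                // created in State::new; map them back to Anthropic's thinking
                // mode with a fixed reasoning budget.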
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

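/// A single item in the cloud completion stream: either a status update
/// injected by the Zed server or an event from the upstream provider.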
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}

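/// Flattens a stream of `CloudCompletionEvent`s into language model completion
/// events, passing status updates through and mapping provider events with the
/// supplied callback.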
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

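/// Emits a single synthetic `UsageUpdated` status event when usage was reported
/// via response headers (i.e. when the server does not stream status messages).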
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

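/// Emits a single `ToolUseLimitReached` status event when the corresponding
/// response header was present.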
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

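/// Parses the streaming response body as newline-delimited JSON. When the
/// server streams status messages, each line is a tagged `CloudCompletionEvent`;
/// otherwise every line is a bare provider event.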
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_connected = !self.state.read(cx).is_signed_out();
        let user_store = self.state.read(cx).user_store.read(cx);
        let plan = user_store.current_plan();
        let subscription_period = user_store.subscription_period();
        let eligible_for_trial = user_store.trial_started_at().is_none();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (plan, subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if eligible_for_trial {
1099                    "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
                )
                .child(
                    Button::new("upgrade", "Upgrade")
                        .style(ButtonStyle::Subtle)
                        .color(Color::Accent)
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                )
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}