cloud.rs

use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
    stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::str::FromStr as _;
use std::{
    sync::{Arc, LazyLock},
    time::Duration,
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
};

use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
use crate::provider::google::into_google;
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

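/// Returns models baked in at build time via the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable, which
/// holds a JSON array of `AvailableModel` entries (see the struct below).
/// An illustrative payload, with made-up values:
///
/// ```json
/// [
///     {
///         "provider": "anthropic",
///         "name": "claude-3-5-sonnet-20240620",
///         "display_name": "Claude 3.5 Sonnet (Beta)",
///         "max_tokens": 200000
///     }
/// ]
/// ```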
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}

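/// How the mode appears in settings JSON: the serde attributes below produce
/// an internally tagged object. A minimal sketch of the two accepted shapes
/// (the `budget_tokens` value is illustrative):
///
/// ```json
/// { "type": "default" }
/// { "type": "thinking", "budget_tokens": 4096 }
/// ```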
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

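/// Provider-wide state shared by the configuration view and the models: the
/// connection status, the LLM API token, and the in-flight task (if any) for
/// accepting the Terms of Service. Held in a `gpui::Entity` so views can
/// observe it; `cx.notify()` is fired whenever it changes.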
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

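    /// Wraps a `CloudModel` in the `LanguageModel` trait object used by the
    /// rest of the app. Each model instance gets its own `RateLimiter`
    /// created with a limit of 4, which (as the field name suggests) bounds
    /// how many requests for that model are in flight at once.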
    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::default());
        Some(self.create_language_model(model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::default_fast());
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override the built-in models with available models from settings,
        // merged with any closed-beta models.
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

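/// Renders the "accept the Terms of Service" prompt, or returns `None` if
/// the user has already accepted. The layout adapts to `view_kind`: the
/// thread empty state gets a compact single row, while the other views get
/// a stacked form with the accept button aligned per view.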
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

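    /// Performs a completion request against the LLM endpoint (or the
    /// `ZED_COMPLETIONS_URL` override), retrying on failure up to
    /// `MAX_RETRIES` times. An expired-token response triggers a token
    /// refresh; 5xx responses are retried with exponential backoff
    /// (1s, 2s, 4s); known 403 and 402 responses are surfaced as typed
    /// errors (`MaxMonthlySpendReachedError`, `ModelRequestLimitReachedError`,
    /// `PaymentRequiredError`); anything else becomes an `ApiError`.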
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: CompletionBody,
    ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
            {
                request_builder.uri(completions_url)
            } else {
                request_builder.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
            };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let usage = RequestUsage::from_headers(response.headers()).ok();

                return Ok((response, usage));
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                // Guard against underflowing `retries_remaining` if the token
                // keeps coming back expired.
                if retries_remaining == 0 {
                    return Err(anyhow!(
                        "cloud language model completion failed: unable to refresh the expired LLM token after {} attempts",
                        Self::MAX_RETRIES
                    ));
                }
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                return Err(anyhow!(MaxMonthlySpendReachedError));
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::Free => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                return Err(anyhow!("Forbidden"));
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    ));
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // Exponential backoff: wait twice as long before the next attempt.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model completion failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        self.stream_completion_with_usage(request, cx)
            .map(|result| result.map(|(stream, _)| stream))
            .boxed()
    }

    fn stream_completion_with_usage(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<(
            BoxStream<'static, Result<LanguageModelCompletionEvent>>,
            Option<RequestUsage>,
        )>,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream_with_usage(async move {
                    let (response, usage) = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode: Some(CompletionMode::Max),
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    Ok((
                        crate::provider::anthropic::map_to_language_model_completion_events(
                            Box::pin(response_lines(response).map_err(AnthropicError::Other)),
                        ),
                        usage,
                    ))
                });
                async move {
                    let (stream, usage) = future.await?;
                    Ok((stream.boxed(), usage))
                }
                .boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream_with_usage(async move {
                    let (response, usage) = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode: Some(CompletionMode::Max),
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;
                    Ok((
                        crate::provider::open_ai::map_to_language_model_completion_events(
                            Box::pin(response_lines(response)),
                        ),
                        usage,
                    ))
                });
                async move {
                    let (stream, usage) = future.await?;
                    Ok((stream.boxed(), usage))
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream_with_usage(async move {
                    let (response, usage) = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode: Some(CompletionMode::Max),
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;
                    Ok((
                        crate::provider::google::map_to_language_model_completion_events(Box::pin(
                            response_lines(response),
                        )),
                        usage,
                    ))
                });
                async move {
                    let (stream, usage) = future.await?;
                    Ok((stream.boxed(), usage))
                }
                .boxed()
            }
        }
    }
}

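/// Adapts a streaming HTTP response into a stream of deserialized values.
/// The completion endpoint emits newline-delimited JSON, so each line of the
/// body is parsed as one `T`. A minimal usage sketch (the concrete event
/// type is illustrative):
///
/// ```ignore
/// let mut events = Box::pin(response_lines::<serde_json::Value>(response));
/// while let Some(event) = events.next().await {
///     println!("{:?}", event?);
/// }
/// ```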
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}