cloud.rs

use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{
    Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
    PerformCompletionParams, UserStore, zed_urls,
};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
    stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::str::FromStr as _;
use std::{
    sync::{Arc, LazyLock},
    time::Duration,
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{CURRENT_PLAN_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME};

use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
use crate::provider::google::into_google;
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| {
                serde_json::from_str(json)
                    .expect("failed to parse ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON")
            })
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}
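
// A hypothetical value for `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON`, shown
// only to illustrate the expected shape (a JSON array of `AvailableModel`
// entries); the real list is injected at build time and the model name below
// is made up:
//
// [
//   {
//     "provider": "anthropic",
//     "name": "claude-next-preview",
//     "max_tokens": 200000
//   }
// ]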

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}
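
// A hypothetical settings entry that deserializes into `AvailableModel`; the
// field names follow the serde attributes above, and the values are
// illustrative only:
//
// {
//   "provider": "anthropic",
//   "name": "claude-3-5-sonnet-20240620",
//   "display_name": "Claude 3.5 Sonnet (custom)",
//   "max_tokens": 200000,
//   "max_output_tokens": 8192,
//   "default_temperature": 1.0,
//   "mode": { "type": "thinking", "budget_tokens": 4096 }
// }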

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
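
// With the internally tagged representation above, the two variants serialize
// as `{"type":"default"}` and, for example,
// `{"type":"thinking","budget_tokens":4096}`.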

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        }) as Arc<dyn LanguageModel>
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::default());
        Some(Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        }))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Merge in custom models from the settings (and, behind the closed-beta
        // flag, the additional models list), overriding any defaults above that
        // share the same model ID.
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
            {
                request_builder.uri(completions_url)
            } else {
                request_builder.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
            };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                return Ok(response);
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                // Guard against underflow if the token keeps coming back expired.
                if retries_remaining == 0 {
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries: expired LLM token",
                        Self::MAX_RETRIES
                    ));
                }
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                return Err(anyhow!(MaxMonthlySpendReachedError));
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some("model_requests") = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::Free => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                return Err(anyhow!("Forbidden"));
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    ));
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}
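
// With `MAX_RETRIES = 3` and an initial delay of one second, the 5xx retry
// loop above waits 1s, 2s, and then 4s between attempts before giving up.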

#[derive(Debug, Error)]
#[error("cloud language model completion failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    Ok(
                        crate::provider::anthropic::map_to_language_model_completion_events(
                            Box::pin(response_lines(response).map_err(AnthropicError::Other)),
                        ),
                    )
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(
                        crate::provider::open_ai::map_to_language_model_completion_events(
                            Box::pin(response_lines(response)),
                        ),
                    )
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(
                        crate::provider::google::map_to_language_model_completion_events(Box::pin(
                            response_lines(response),
                        )),
                    )
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
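
// `response_lines` assumes the completion endpoint streams newline-delimited
// JSON: each line of the body is one standalone event object, deserialized
// independently. Illustrative input (the actual event shapes are
// provider-specific):
//
//   {"type":"message_start"}
//   {"type":"content_block_delta","text":"Hello"}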

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}