use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
    stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
    LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
    MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
};
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::str::FromStr as _;
use std::{
    sync::{Arc, LazyLock},
    time::Duration,
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
    CURRENT_PLAN_HEADER_NAME, CompletionBody, EXPIRED_LLM_TOKEN_HEADER_NAME,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
};

use crate::AllLanguageModelSettings;
use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
use crate::provider::google::into_google;
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

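/// Returns any additional models for the Zed cloud provider that were compiled
/// in via `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON`, parsing the JSON once on
/// first access.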
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

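/// Settings for the Zed-hosted (`zed.dev`) language model provider.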
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

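/// The upstream provider of a model served through Zed's cloud.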
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

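/// A model made available through the Zed cloud provider, configured via
/// settings or the closed-beta model list.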
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}

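/// The mode in which a model operates (see `AnthropicModelMode`).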
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

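/// The language model provider for models hosted by Zed.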
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

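/// Observable state for the Zed provider: client connection status, the LLM
/// API token, and terms-of-service acceptance.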
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: CloudModel,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        }) as Arc<dyn LanguageModel>
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::default());
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model);
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

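/// Renders the terms-of-service notice and accept button for the given view,
/// or `None` if the user has already accepted the terms.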
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

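/// A language model served through Zed's cloud completion endpoint.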
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

impl CloudLanguageModel {
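    /// The maximum number of retries for a completion request after a
    /// retryable failure (an expired token or a 5xx response).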
    const MAX_RETRIES: usize = 3;

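    /// Sends a completion request to the LLM endpoint, refreshing the access
    /// token and retrying with exponential backoff where appropriate.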
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: CompletionBody,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
            {
                request_builder.uri(completions_url)
            } else {
                request_builder.uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
            };
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                return Ok(response);
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                // The token expired; refresh it and retry, but don't decrement
                // past zero (a usize underflow) if the server keeps rejecting us.
                if retries_remaining == 0 {
                    return Err(anyhow!(
                        "cloud language model completion failed: unable to refresh the expired LLM token after {} attempts",
                        Self::MAX_RETRIES
                    ));
                }
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                return Err(anyhow!(MaxMonthlySpendReachedError));
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some("model_requests") = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::Free => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                return Err(anyhow!("Forbidden"));
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    return Err(anyhow!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    ));
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model completion failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    Ok(
                        crate::provider::anthropic::map_to_language_model_completion_events(
                            Box::pin(response_lines(response).map_err(AnthropicError::Other)),
                        ),
                    )
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;
                    Ok(
                        crate::provider::open_ai::map_to_language_model_completion_events(
                            Box::pin(response_lines(response)),
                        ),
                    )
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;
                    Ok(
                        crate::provider::google::map_to_language_model_completion_events(Box::pin(
                            response_lines(response),
                        )),
                    )
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

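/// Deserializes a newline-delimited JSON response body into a stream of
/// events, one event per line.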
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

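/// The configuration view for the Zed provider: signing in, accepting the
/// terms of service, and managing the subscription.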
struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}