cloud.rs

use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{
    zed_urls, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal,
    Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, CloudModel, LanguageModel, LanguageModelAvailability,
    LanguageModelCacheConfiguration, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, RateLimiter,
    ZED_CLOUD_PROVIDER_ID,
};
use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::{AsyncReadExt, BufReader},
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::fmt;
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor};

use crate::provider::anthropic::{
    count_anthropic_tokens, into_anthropic, map_to_language_model_completion_events,
};
use crate::provider::google::into_google;
use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
use crate::AllLanguageModelSettings;

pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}
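
// An illustrative sketch of the JSON this environment variable is expected to
// hold: an array of `AvailableModel` objects matching the schema defined
// below. The model name here is an example, not prescriptive:
//
//     [{"provider": "anthropic",
//       "name": "claude-3-5-sonnet-20240620",
//       "display_name": "Sonnet (closed beta)",
//       "max_tokens": 200000}]
//
// Parsing happens once, on first access, and a malformed value panics via the
// `unwrap` above.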

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. "claude-3-5-sonnet-20240620".
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Configuration for caching, if this custom model supports it.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
}
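
// Settings-defined models are merged into the provider's list in
// `provided_models` below. An illustrative settings snippet; the surrounding
// "language_models"/"zed.dev" key path is an assumption about the settings
// layout, which is not defined in this file:
//
//     "language_models": {
//         "zed.dev": {
//             "available_models": [
//                 {"provider": "openai", "name": "gpt-4o", "max_tokens": 128000}
//             ]
//         }
//     }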

struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}

pub struct RefreshLlmTokenEvent;

pub struct RefreshLlmTokenListener {
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}

impl RefreshLlmTokenListener {
    pub fn register(client: Arc<Client>, cx: &mut App) {
        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

    pub fn global(cx: &App) -> Entity<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
        Self {
            _llm_token_subscription: client
                .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token),
        }
    }

    async fn handle_refresh_llm_token(
        this: Entity<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncApp,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}
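
// A minimal usage sketch: `register` is expected to run once during app
// startup (outside this file), after which any entity can subscribe to
// token-refresh pushes from the server, mirroring what `State::new` does
// below:
//
//     RefreshLlmTokenListener::register(client.clone(), cx);
//     // ...later, from some entity's `Context`:
//     let listener = RefreshLlmTokenListener::global(cx);
//     cx.subscribe(&listener, |this, _listener, _event, cx| {
//         // re-fetch the LLM token here
//     });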

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(|_this, _cx| async move {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::default());
        Some(Arc::new(CloudLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        }))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let text = "To start using Zed AI, please read and accept the";

    let form = v_flex()
        .w_full()
        .gap_2()
        .when(
            view_kind == LanguageModelProviderTosView::ThreadEmptyState,
            |form| form.items_center(),
        )
        .child(
            h_flex()
                .flex_wrap()
                .when(
                    view_kind == LanguageModelProviderTosView::ThreadEmptyState,
                    |form| form.justify_center(),
                )
                .child(Label::new(text))
                .child(terms_button),
        )
        .child({
            let button_container = h_flex().w_full().child(
                Button::new("accept_terms", "I accept the Terms of Service")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .disabled(accept_terms_disabled)
                    .on_click({
                        let state = state.downgrade();
                        move |_, _window, cx| {
                            state
                                .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                .ok();
                        }
                    }),
            );

            match view_kind {
                LanguageModelProviderTosView::ThreadEmptyState => button_container.justify_center(),
                LanguageModelProviderTosView::PromptEditorPopup => button_container.justify_end(),
                LanguageModelProviderTosView::Configuration => button_container.justify_start(),
            }
        });

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

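/// A cached LLM API bearer token shared across clones.
///
/// The `RwLock` lets concurrent requests read a cached token while a single
/// caller refreshes an expired one; see the `acquire`/`refresh` impl near the
/// bottom of this file.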
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);

#[derive(Error, Debug)]
pub struct PaymentRequiredError;

impl fmt::Display for PaymentRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Payment required to use this language model. Please upgrade your account."
        )
    }
}

#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;

impl fmt::Display for MaxMonthlySpendReachedError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
        )
    }
}

impl CloudLanguageModel {
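    /// Sends a completion request to the Zed LLM endpoint, acquiring an API
    /// token on first use and retrying exactly once with a refreshed token
    /// when the server flags the current one as expired. A forbidden response
    /// carrying the max-monthly-spend header, or a payment-required status,
    /// maps to the typed errors defined above.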
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

        let response = loop {
            let request_builder = http_client::Request::builder();
            let request = request_builder
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else if response.status() == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                break Err(anyhow!(MaxMonthlySpendReachedError))?;
            } else if response.status() == StatusCode::PAYMENT_REQUIRED {
                break Err(anyhow!(PaymentRequiredError))?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                break Err(anyhow!(
                    "cloud language model completion failed with status {}: {body}",
                    response.status()
                ))?;
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request = into_anthropic(
                    request,
                    model.tool_model_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    into_open_ai(request, model.id().into(), model.max_output_tokens());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
        }
    }
}

fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
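
// `response_lines` assumes a newline-delimited JSON body: one complete JSON
// value per line, deserialized into `T` as it is read. Illustrative input
// (the exact event shapes are provider-specific):
//
//     {"type":"message_start"}
//     {"type":"message_stop"}
//
// A zero-byte read ends the stream; a line that fails to parse is yielded as
// an `Err` item.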

impl LlmApiToken {
    pub async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    pub async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}
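
// A minimal usage sketch (error handling elided; assumes an `Arc<Client>`
// connected to the Zed server is in scope):
//
//     let token = LlmApiToken::default();
//     let bearer = token.acquire(&client).await?; // fetches once, then caches
//     let bearer = token.refresh(&client).await?; // forces a new token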

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}