use super::open_ai::count_open_ai_tokens;
use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{
    zed_urls, Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME,
    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
use gpui::{
    AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
    ModelContext, ReadGlobal, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use proto::TypedEnvelope;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::{AsyncReadExt, BufReader},
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::fmt;
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{prelude::*, TintColor};

use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

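/// Additional models offered through the Zed cloud provider, supplied at build time via the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable and parsed once on first use.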
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// The prompt-caching configuration for this custom model, if it supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
}

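/// Global handle to the app-wide [`RefreshLlmTokenListener`].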
struct GlobalRefreshLlmTokenListener(Model<RefreshLlmTokenListener>);

impl Global for GlobalRefreshLlmTokenListener {}

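/// Emitted when the server asks connected clients to refresh their LLM API token.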
pub struct RefreshLlmTokenEvent;

pub struct RefreshLlmTokenListener {
    _llm_token_subscription: client::Subscription,
}

impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}

impl RefreshLlmTokenListener {
    pub fn register(client: Arc<Client>, cx: &mut AppContext) {
        let listener = cx.new_model(|cx| RefreshLlmTokenListener::new(client, cx));
        cx.set_global(GlobalRefreshLlmTokenListener(listener));
    }

    pub fn global(cx: &AppContext) -> Model<Self> {
        GlobalRefreshLlmTokenListener::global(cx).0.clone()
    }

    fn new(client: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
        Self {
            _llm_token_subscription: client
                .add_message_handler(cx.weak_model(), Self::handle_refresh_llm_token),
        }
    }

    async fn handle_refresh_llm_token(
        this: Model<Self>,
        _: TypedEnvelope<proto::RefreshLlmToken>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent))
    }
}

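/// Language model provider backed by Zed's hosted LLM service (zed.dev).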
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

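/// Provider state: the connection status, the cached LLM API token, and whether the current user
/// has accepted the terms of service.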
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Model<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Model<UserStore>,
        status: client::Status,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(|_this, _cx| async move {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
        let state = self.state.read(cx);

        let terms = [(
            "terms_of_service",
            "Terms of Service",
            "https://zed.dev/terms-of-service",
        )]
        .map(|(id, label, url)| {
            Button::new(id, label)
                .style(ButtonStyle::Subtle)
                .icon(IconName::ExternalLink)
                .icon_size(IconSize::XSmall)
                .icon_color(Color::Muted)
                .on_click(move |_, cx| cx.open_url(url))
        });

        if state.has_accepted_terms_of_service(cx) {
            None
        } else {
            let disabled = state.accept_terms.is_some();
            Some(
                v_flex()
                    .gap_2()
                    .child(
                        v_flex()
                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
                            .child(
                                Label::new(
                                    "Please read and accept our terms and conditions to continue.",
                                )
                                .size(LabelSize::Small),
                            ),
                    )
                    .child(v_flex().gap_1().children(terms))
                    .child(
                        h_flex().justify_end().child(
                            Button::new("accept_terms", "I've read it and accept it")
                                .disabled(disabled)
                                .on_click({
                                    let state = self.state.downgrade();
                                    move |_, cx| {
                                        state
                                            .update(cx, |state, cx| {
                                                state.accept_terms_of_service(cx)
                                            })
                                            .ok();
                                    }
                                }),
                        ),
                    )
                    .into_any(),
            )
        }
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

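/// Bearer token for Zed's LLM service, cached in memory and shared by all cloud models.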
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

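/// Returned when the LLM service rejects a completion with HTTP 402 (payment required).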
#[derive(Error, Debug)]
pub struct PaymentRequiredError;

impl fmt::Display for PaymentRequiredError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Payment required to use this language model. Please upgrade your account."
        )
    }
}

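/// Returned when a completion is rejected because the monthly spending limit was reached.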
#[derive(Error, Debug)]
pub struct MaxMonthlySpendReachedError;

impl fmt::Display for MaxMonthlySpendReachedError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Maximum spending limit reached for this month. For more usage, increase your spending limit."
        )
    }
}

impl CloudLanguageModel {
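    /// Sends a completion request to Zed's LLM service, attaching the cached LLM token and
    /// mapping known failures (expired token, spending cap, payment required) to typed errors.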
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

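        // Retry the request once with a freshly fetched token if the server reports that the
        // current LLM token has expired; all other failures break out of the loop immediately.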
        let response = loop {
            let request_builder = http_client::Request::builder();
            let request = request_builder
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else if response.status() == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                    .is_some()
            {
                break Err(anyhow!(MaxMonthlySpendReachedError))?;
            } else if response.status() == StatusCode::PAYMENT_REQUIRED {
                break Err(anyhow!(PaymentRequiredError))?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                break Err(anyhow!(
                    "cloud language model completion failed with status {}: {body}",
                    response.status()
                ))?;
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

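    // Streaming completions are proxied through Zed's LLM service: the request is converted into
    // the upstream provider's wire format, wrapped in `PerformCompletionParams`, and the streamed
    // response body is decoded as newline-delimited JSON events.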
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(
                    model.id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

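    // Tool use forces the upstream model to call the given tool and streams back the raw JSON
    // arguments of that call. Google models do not support this path yet.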
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request = request.into_anthropic(
                    model.tool_model_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
        }
    }
}

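/// Converts a streaming response body into a stream of values parsed from newline-delimited JSON.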
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

impl LlmApiToken {
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
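        // Take an upgradable read lock so concurrent callers can share a cached token, and only
        // upgrade to a write lock when the token must be fetched for the first time.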
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token.clone())
    }
}

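/// Configuration UI for the Zed provider: sign-in, terms-of-service acceptance, and plan details.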
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }

    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
        if self.state.read(cx).has_accepted_terms_of_service(cx) {
            return None;
        }

        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();

        let terms_button = Button::new("terms_of_service", "Terms of Service")
            .style(ButtonStyle::Subtle)
            .icon(IconName::ExternalLink)
            .icon_color(Color::Muted)
            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));

        let text =
            "In order to use Zed AI, please read and accept our terms and conditions to continue:";

        let form = v_flex()
            .gap_2()
            .child(Label::new("Terms and Conditions"))
            .child(Label::new(text))
            .child(h_flex().justify_center().child(terms_button))
            .child(
                h_flex().justify_center().child(
                    Button::new("accept_terms", "I've read and accept the terms of service")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .disabled(accept_terms_disabled)
                        .on_click({
                            let state = self.state.downgrade();
                            move |_, cx| {
                                state
                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                    .ok();
                            }
                        }),
                ),
            );

        Some(form.into_any())
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(cx.listener(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .children(self.render_accept_terms(cx))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                )
        }
    }
}