cloud.rs

use super::open_ai::count_open_ai_tokens;
use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
use gpui::{
    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
    Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::{AsyncReadExt, BufReader},
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::time::Duration;
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};

use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

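/// Additional models baked in at compile time via
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON`, parsed once on first access.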
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| {
                serde_json::from_str(json)
                    .expect("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON must be valid JSON")
            })
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

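/// User-configurable settings for the zed.dev provider.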
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
    pub low_speed_timeout: Option<Duration>,
}

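/// The upstream provider of an available model, as it appears in settings
/// (serialized in lowercase, e.g. `"anthropic"`).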
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

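/// A model exposed through the zed.dev provider. A settings entry might look
/// like the following (field values are illustrative, not a recommended
/// configuration):
///
/// ```json
/// {
///     "provider": "anthropic",
///     "name": "claude-3-5-sonnet-20240620",
///     "display_name": "Claude 3.5 Sonnet",
///     "max_tokens": 200000
/// }
/// ```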
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (`o1-*` models only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
}

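/// The zed.dev language model provider: a thin client over Zed's hosted LLM
/// service, holding the authenticated `Client`, a cached LLM API token, and
/// the observable sign-in state.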
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

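/// Observable provider state: the client's connection status and whether the
/// signed-in user has accepted the terms of service.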
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
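        // Collect into a BTreeMap keyed by model ID so the resulting list is
        // sorted and de-duplicated; later inserts (e.g. settings overrides)
        // replace earlier entries.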
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
        let state = self.state.read(cx);

        let terms = [(
            "terms_of_service",
            "Terms of Service",
            "https://zed.dev/terms-of-service",
        )]
        .map(|(id, label, url)| {
            Button::new(id, label)
                .style(ButtonStyle::Subtle)
                .icon(IconName::ExternalLink)
                .icon_size(IconSize::XSmall)
                .icon_color(Color::Muted)
                .on_click(move |_, cx| cx.open_url(url))
        });

        if state.has_accepted_terms_of_service(cx) {
            None
        } else {
            let disabled = state.accept_terms.is_some();
            Some(
                v_flex()
                    .gap_2()
                    .child(
                        v_flex()
                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
                            .child(
                                Label::new(
                                    "Please read and accept our terms and conditions to continue.",
                                )
                                .size(LabelSize::Small),
                            ),
                    )
                    .child(v_flex().gap_1().children(terms))
                    .child(
                        h_flex().justify_end().child(
                            Button::new("accept_terms", "I've read it and accept it")
                                .disabled(disabled)
                                .on_click({
                                    let state = self.state.downgrade();
                                    move |_, cx| {
                                        state
                                            .update(cx, |state, cx| {
                                                state.accept_terms_of_service(cx)
                                            })
                                            .ok();
                                    }
                                }),
                        ),
                    )
                    .into_any(),
            )
        }
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

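/// A lazily fetched LLM API token, shared by all models from this provider.
/// The token is fetched on first use and refreshed when the server reports it
/// as expired via `EXPIRED_LLM_TOKEN_HEADER_NAME`.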
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
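    /// Sends a completion request to the zed.dev LLM endpoint, retrying once
    /// with a freshly fetched token if the cached token has expired.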
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
        low_speed_timeout: Option<Duration>,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

        let response = loop {
            let mut request_builder = http_client::Request::builder();
            if let Some(low_speed_timeout) = low_speed_timeout {
                request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
            };
            let request = request_builder
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                break Err(anyhow!(
                    "cloud language model completion failed with status {}: {body}",
                    response.status()
                ))?;
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        // `try_read_global` yields `Option<Option<Duration>>` here; flatten it
        // rather than unwrapping, so an unset timeout doesn't panic.
        let openai_low_speed_timeout =
            AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout)
                .flatten();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(
                    model.id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        openai_low_speed_timeout,
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into(), None);
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request = request.into_anthropic(
                    model.tool_model_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}

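/// Parses a streaming response body as newline-delimited JSON, yielding one
/// deserialized value per line.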
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

impl LlmApiToken {
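    /// Returns the cached token, fetching one from the server on first use.
    /// Taking an upgradable read lock ensures concurrent callers don't race
    /// to fetch duplicate tokens.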
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

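/// The provider's configuration UI: prompts the user to sign in and to accept
/// the terms of service, and summarizes their current plan.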
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }

    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
        if self.state.read(cx).has_accepted_terms_of_service(cx) {
            return None;
        }

        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();

        let terms_button = Button::new("terms_of_service", "Terms of Service")
            .style(ButtonStyle::Subtle)
            .icon(IconName::ExternalLink)
            .icon_color(Color::Muted)
            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));

        let text =
            "To use Zed AI, please read and accept our terms and conditions:";

        let form = v_flex()
            .gap_2()
            .child(Label::new("Terms and Conditions"))
            .child(Label::new(text))
            .child(h_flex().justify_center().child(terms_button))
            .child(
                h_flex().justify_center().child(
                    Button::new("accept_terms", "I've read and accept the terms of service")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .disabled(accept_terms_disabled)
                        .on_click({
                            let state = self.state.downgrade();
                            move |_, cx| {
                                state
                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                    .ok();
                            }
                        }),
                ),
            );

        Some(form.into_any())
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .children(self.render_accept_terms(cx))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}