cloud.rs

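//! The zed.dev cloud language model provider, which proxies requests for
//! Anthropic, OpenAI, Google, and Zed-hosted models through Zed's LLM
//! service.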
use super::open_ai::count_open_ai_tokens;
use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
use gpui::{
    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
    Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::{AsyncReadExt, BufReader},
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::time::Duration;
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};

use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

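/// Additional models injected at compile time via the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable, parsed
/// once on first access. Panics on first use if the JSON is malformed.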
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

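/// User-configurable settings for the zed.dev provider.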
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
    pub low_speed_timeout: Option<Duration>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* models only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
}

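/// The zed.dev language model provider: exposes Anthropic, OpenAI, Google,
/// and Zed-hosted models through Zed's LLM service.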
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

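/// Observable provider state: the client connection status and whether the
/// user has accepted the terms of service.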
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

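        // Watch the client's connection status and notify observers whenever
        // it changes, for as long as the state model is alive.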
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

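        // Staff see every built-in model; everyone else gets Claude 3.5
        // Sonnet, plus whatever the settings and feature flags add below.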
        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Add models from settings and the closed beta, overriding any
        // built-in model with the same id.
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
        let state = self.state.read(cx);

        let terms = [(
            "terms_of_service",
            "Terms of Service",
            "https://zed.dev/terms-of-service",
        )]
        .map(|(id, label, url)| {
            Button::new(id, label)
                .style(ButtonStyle::Subtle)
                .icon(IconName::ExternalLink)
                .icon_size(IconSize::XSmall)
                .icon_color(Color::Muted)
                .on_click(move |_, cx| cx.open_url(url))
        });

        if state.has_accepted_terms_of_service(cx) {
            None
        } else {
            let disabled = state.accept_terms.is_some();
            Some(
                v_flex()
                    .gap_2()
                    .child(
                        v_flex()
                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
                            .child(
                                Label::new(
                                    "Please read and accept our terms and conditions to continue.",
                                )
                                .size(LabelSize::Small),
                            ),
                    )
                    .child(v_flex().gap_1().children(terms))
                    .child(
                        h_flex().justify_end().child(
                            Button::new("accept_terms", "I've read and accept the terms of service")
                                .disabled(disabled)
                                .on_click({
                                    let state = self.state.downgrade();
                                    move |_, cx| {
                                        state
                                            .update(cx, |state, cx| {
                                                state.accept_terms_of_service(cx)
                                            })
                                            .ok();
                                    }
                                }),
                        ),
                    )
                    .into_any(),
            )
        }
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

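/// A cached LLM API token shared by all models from this provider.
///
/// The token lives behind an async `RwLock`: `acquire` returns the cached
/// value (fetching it on first use), and `refresh` force-fetches a new one
/// after the server reports expiry via `EXPIRED_LLM_TOKEN_HEADER_NAME`.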
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
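    /// Sends a completion request to the Zed LLM service. If the server
    /// reports that the current token has expired, the token is refreshed
    /// and the request is retried once.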
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
        low_speed_timeout: Option<Duration>,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

        let response = loop {
            let mut request_builder = http_client::Request::builder();
            if let Some(low_speed_timeout) = low_speed_timeout {
                request_builder = request_builder.read_timeout(low_speed_timeout);
            }
            let request = request_builder
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(
                    "cloud language model completion failed with status {}: {body}",
                    response.status()
                ));
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
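            // Zed-hosted models are OpenAI-based (see `use_any_tool` below),
            // so approximate their token counts with an OpenAI tokenizer.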
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let openai_low_speed_timeout =
            AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout)
                .flatten();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(
                    model.id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        openai_low_speed_timeout,
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into(), None);
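                // Cap completions from Zed-hosted models at 4000 output
                // tokens.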
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

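    /// Forces the model to call the given tool, streaming back the tool's
    /// JSON arguments as they arrive.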
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request = request.into_anthropic(
                    model.tool_model_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}

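/// Parses a streaming response body as newline-delimited JSON, yielding one
/// deserialized `T` per line until the body is exhausted.
///
/// A sketch of the wire format this expects (assuming Anthropic-style
/// streaming events; the exact fields depend on the provider):
///
/// ```text
/// {"type":"message_start",...}
/// {"type":"content_block_delta",...}
/// ```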
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

impl LlmApiToken {
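    /// Returns the cached token, fetching one if none is cached yet. Uses an
    /// upgradable read lock so concurrent callers don't fetch redundantly.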
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

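/// The provider's configuration UI: prompts the user to sign in and accept
/// the terms of service, and summarizes their current plan.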
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }

    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
        if self.state.read(cx).has_accepted_terms_of_service(cx) {
            return None;
        }

        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();

        let terms_button = Button::new("terms_of_service", "Terms of Service")
            .style(ButtonStyle::Subtle)
            .icon(IconName::ExternalLink)
            .icon_color(Color::Muted)
            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));

        let text =
            "To use Zed AI, please read and accept our terms and conditions:";

        let form = v_flex()
            .gap_2()
            .child(Label::new("Terms and Conditions"))
            .child(Label::new(text))
            .child(h_flex().justify_center().child(terms_button))
            .child(
                h_flex().justify_center().child(
                    Button::new("accept_terms", "I've read and accept the terms of service")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .disabled(accept_terms_disabled)
                        .on_click({
                            let state = self.state.downgrade();
                            move |_, cx| {
                                state
                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                    .ok();
                            }
                        }),
                ),
            );

        Some(form.into_any())
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .children(self.render_accept_terms(cx))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}