cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
  4    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
  6};
  7use anthropic::AnthropicError;
  8use anyhow::{anyhow, Result};
  9use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 10use collections::BTreeMap;
 11use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
 12use futures::{
 13    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
 14    TryStreamExt as _,
 15};
 16use gpui::{
 17    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
 18    Subscription, Task,
 19};
 20use http_client::{AsyncBody, HttpClient, Method, Response};
 21use schemars::JsonSchema;
 22use serde::{de::DeserializeOwned, Deserialize, Serialize};
 23use serde_json::value::RawValue;
 24use settings::{Settings, SettingsStore};
 25use smol::{
 26    io::{AsyncReadExt, BufReader},
 27    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 28};
 29use std::{
 30    future,
 31    sync::{Arc, LazyLock},
 32};
 33use strum::IntoEnumIterator;
 34use ui::{prelude::*, TintColor};
 35
 36use crate::{LanguageModelAvailability, LanguageModelProvider};
 37
 38use super::anthropic::count_anthropic_tokens;
 39
/// Provider identifier, returned by `CloudLanguageModelProvider::id` and
/// `CloudLanguageModel::provider_id`.
pub const PROVIDER_ID: &str = "zed.dev";
/// Provider name, returned by `CloudLanguageModelProvider::name` and
/// `CloudLanguageModel::provider_name`.
pub const PROVIDER_NAME: &str = "Zed";

/// JSON array of extra `AvailableModel`s captured from the build environment by
/// `option_env!`; `None` when the variable was not set at compile time. Parsed lazily by
/// `zed_cloud_provider_additional_models`.
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
 45
 46fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
 47    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
 48        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
 49            .map(|json| serde_json::from_str(json).unwrap())
 50            .unwrap_or(Vec::new())
 51    });
 52    ADDITIONAL_MODELS.as_slice()
 53}
 54
/// Settings for the zed.dev provider, accessed as `AllLanguageModelSettings::zed_dot_dev`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Models declared in settings. `provided_models` inserts them after the built-in
    /// models, keyed by model id, so an entry here replaces a built-in model with the same id.
    pub available_models: Vec<AvailableModel>,
}
 59
// Which `CloudModel` family a custom `AvailableModel` is converted into by
// `provided_models`. With `rename_all = "lowercase"` the serialized names are
// "anthropic", "openai" and "google".
//
// Plain `//` comments are used on purpose: `///` doc comments would be picked up by the
// `JsonSchema` derive as schema descriptions and change the generated schema.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
 67
// One model entry, taken either from the `zed_dot_dev.available_models` settings or from
// the compile-time `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` list. `provided_models`
// converts it into the matching provider's `Model::Custom` variant.
//
// NOTE: the `///` field docs below are picked up by the `JsonSchema` derive as schema
// descriptions, so editing them changes the generated schema; the `//` notes are
// maintainer-only.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    // Currently only forwarded for Anthropic models; the OpenAI and Google `Custom`
    // variants are built without it in `provided_models`.
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    // Forwarded for Anthropic and OpenAI models; not used for Google.
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    // Only used for Anthropic models.
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    // Only used for Anthropic models (converted to `AnthropicModelCacheConfiguration`).
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}
 85
/// The "zed.dev" language model provider. Its models send completions to the Zed LLM
/// endpoint (see `CloudLanguageModel::perform_llm_completion`).
pub struct CloudLanguageModelProvider {
    /// Client handed (cloned) to every `CloudLanguageModel` this provider creates.
    client: Arc<Client>,
    /// Token cache cloned into every model; clones share the same `Arc`-backed slot.
    llm_api_token: LlmApiToken,
    /// Observable provider state (client status, terms acceptance).
    state: gpui::Model<State>,
    /// Background task that mirrors `client.status()` changes into `state`; stored so it
    /// stays alive as long as the provider.
    _maintain_client_status: Task<()>,
}
 92
/// Observable state of the zed.dev provider, shared with its configuration views.
pub struct State {
    /// Client used to sign in (`authenticate`).
    client: Arc<Client>,
    /// Source of the current user's terms-of-service acceptance.
    user_store: Model<UserStore>,
    /// Last client status observed by the task spawned in `CloudLanguageModelProvider::new`.
    status: client::Status,
    /// In-flight "accept terms" request; `Some` while pending, reset to `None` when it
    /// finishes. UIs use it to disable the accept button.
    accept_terms: Option<Task<Result<()>>>,
    /// Global settings observer that re-notifies this model's observers.
    _subscription: Subscription,
}
100
101impl State {
102    fn is_signed_out(&self) -> bool {
103        self.status.is_signed_out()
104    }
105
106    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
107        let client = self.client.clone();
108        cx.spawn(move |this, mut cx| async move {
109            client.authenticate_and_connect(true, &cx).await?;
110            this.update(&mut cx, |_, cx| cx.notify())
111        })
112    }
113
114    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
115        self.user_store
116            .read(cx)
117            .current_user_has_accepted_terms()
118            .unwrap_or(false)
119    }
120
121    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
122        let user_store = self.user_store.clone();
123        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
124            let _ = user_store
125                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
126                .await;
127            this.update(&mut cx, |this, cx| {
128                this.accept_terms = None;
129                cx.notify()
130            })
131        }));
132    }
133}
134
135impl CloudLanguageModelProvider {
136    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
137        let mut status_rx = client.status();
138        let status = *status_rx.borrow();
139
140        let state = cx.new_model(|cx| State {
141            client: client.clone(),
142            user_store,
143            status,
144            accept_terms: None,
145            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
146                cx.notify();
147            }),
148        });
149
150        let state_ref = state.downgrade();
151        let maintain_client_status = cx.spawn(|mut cx| async move {
152            while let Some(status) = status_rx.next().await {
153                if let Some(this) = state_ref.upgrade() {
154                    _ = this.update(&mut cx, |this, cx| {
155                        if this.status != status {
156                            this.status = status;
157                            cx.notify();
158                        }
159                    });
160                } else {
161                    break;
162                }
163            }
164        });
165
166        Self {
167            client,
168            state,
169            llm_api_token: LlmApiToken::default(),
170            _maintain_client_status: maintain_client_status,
171        }
172    }
173}
174
175impl LanguageModelProviderState for CloudLanguageModelProvider {
176    type ObservableEntity = State;
177
178    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
179        Some(self.state.clone())
180    }
181}
182
183impl LanguageModelProvider for CloudLanguageModelProvider {
184    fn id(&self) -> LanguageModelProviderId {
185        LanguageModelProviderId(PROVIDER_ID.into())
186    }
187
188    fn name(&self) -> LanguageModelProviderName {
189        LanguageModelProviderName(PROVIDER_NAME.into())
190    }
191
192    fn icon(&self) -> IconName {
193        IconName::AiZed
194    }
195
196    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
197        let mut models = BTreeMap::default();
198
199        if cx.is_staff() {
200            for model in anthropic::Model::iter() {
201                if !matches!(model, anthropic::Model::Custom { .. }) {
202                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
203                }
204            }
205            for model in open_ai::Model::iter() {
206                if !matches!(model, open_ai::Model::Custom { .. }) {
207                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
208                }
209            }
210            for model in google_ai::Model::iter() {
211                if !matches!(model, google_ai::Model::Custom { .. }) {
212                    models.insert(model.id().to_string(), CloudModel::Google(model));
213                }
214            }
215            for model in ZedModel::iter() {
216                models.insert(model.id().to_string(), CloudModel::Zed(model));
217            }
218        } else {
219            models.insert(
220                anthropic::Model::Claude3_5Sonnet.id().to_string(),
221                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
222            );
223        }
224
225        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
226            zed_cloud_provider_additional_models()
227        } else {
228            &[]
229        };
230
231        // Override with available models from settings
232        for model in AllLanguageModelSettings::get_global(cx)
233            .zed_dot_dev
234            .available_models
235            .iter()
236            .chain(llm_closed_beta_models)
237            .cloned()
238        {
239            let model = match model.provider {
240                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
241                    name: model.name.clone(),
242                    display_name: model.display_name.clone(),
243                    max_tokens: model.max_tokens,
244                    tool_override: model.tool_override.clone(),
245                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
246                        anthropic::AnthropicModelCacheConfiguration {
247                            max_cache_anchors: config.max_cache_anchors,
248                            should_speculate: config.should_speculate,
249                            min_total_token: config.min_total_token,
250                        }
251                    }),
252                    max_output_tokens: model.max_output_tokens,
253                }),
254                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
255                    name: model.name.clone(),
256                    max_tokens: model.max_tokens,
257                    max_output_tokens: model.max_output_tokens,
258                }),
259                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
260                    name: model.name.clone(),
261                    max_tokens: model.max_tokens,
262                }),
263            };
264            models.insert(model.id().to_string(), model.clone());
265        }
266
267        models
268            .into_values()
269            .map(|model| {
270                Arc::new(CloudLanguageModel {
271                    id: LanguageModelId::from(model.id().to_string()),
272                    model,
273                    llm_api_token: self.llm_api_token.clone(),
274                    client: self.client.clone(),
275                    request_limiter: RateLimiter::new(4),
276                }) as Arc<dyn LanguageModel>
277            })
278            .collect()
279    }
280
281    fn is_authenticated(&self, cx: &AppContext) -> bool {
282        !self.state.read(cx).is_signed_out()
283    }
284
285    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
286        Task::ready(Ok(()))
287    }
288
289    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
290        cx.new_view(|_cx| ConfigurationView {
291            state: self.state.clone(),
292        })
293        .into()
294    }
295
296    fn must_accept_terms(&self, cx: &AppContext) -> bool {
297        !self.state.read(cx).has_accepted_terms_of_service(cx)
298    }
299
300    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
301        let state = self.state.read(cx);
302
303        let terms = [(
304            "terms_of_service",
305            "Terms of Service",
306            "https://zed.dev/terms-of-service",
307        )]
308        .map(|(id, label, url)| {
309            Button::new(id, label)
310                .style(ButtonStyle::Subtle)
311                .icon(IconName::ExternalLink)
312                .icon_size(IconSize::XSmall)
313                .icon_color(Color::Muted)
314                .on_click(move |_, cx| cx.open_url(url))
315        });
316
317        if state.has_accepted_terms_of_service(cx) {
318            None
319        } else {
320            let disabled = state.accept_terms.is_some();
321            Some(
322                v_flex()
323                    .gap_2()
324                    .child(
325                        v_flex()
326                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
327                            .child(
328                                Label::new(
329                                    "Please read and accept our terms and conditions to continue.",
330                                )
331                                .size(LabelSize::Small),
332                            ),
333                    )
334                    .child(v_flex().gap_1().children(terms))
335                    .child(
336                        h_flex().justify_end().child(
337                            Button::new("accept_terms", "I've read it and accept it")
338                                .disabled(disabled)
339                                .on_click({
340                                    let state = self.state.downgrade();
341                                    move |_, cx| {
342                                        state
343                                            .update(cx, |state, cx| {
344                                                state.accept_terms_of_service(cx)
345                                            })
346                                            .ok();
347                                    }
348                                }),
349                        ),
350                    )
351                    .into_any(),
352            )
353        }
354    }
355
356    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
357        Task::ready(Ok(()))
358    }
359}
360
/// A single model served through zed.dev; instances are built by
/// `CloudLanguageModelProvider::provided_models`.
pub struct CloudLanguageModel {
    /// Id derived from `model.id()`.
    id: LanguageModelId,
    /// The wrapped model; selects the request format and response parser.
    model: CloudModel,
    /// LLM-service token slot shared (via `Arc`) with the provider and its other models.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Created with `RateLimiter::new(4)`; presumably caps concurrent requests for this
    /// model instance (a fresh limiter is made on each `provided_models` call) — confirm.
    request_limiter: RateLimiter,
}
368
/// Cached token for the Zed LLM service. Clones share the same `Arc<RwLock<..>>` slot;
/// the default value holds `None` until the first fetch.
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
371
372impl CloudLanguageModel {
373    async fn perform_llm_completion(
374        client: Arc<Client>,
375        llm_api_token: LlmApiToken,
376        body: PerformCompletionParams,
377    ) -> Result<Response<AsyncBody>> {
378        let http_client = &client.http_client();
379
380        let mut token = llm_api_token.acquire(&client).await?;
381        let mut did_retry = false;
382
383        let response = loop {
384            let request = http_client::Request::builder()
385                .method(Method::POST)
386                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
387                .header("Content-Type", "application/json")
388                .header("Authorization", format!("Bearer {token}"))
389                .body(serde_json::to_string(&body)?.into())?;
390            let mut response = http_client.send(request).await?;
391            if response.status().is_success() {
392                break response;
393            } else if !did_retry
394                && response
395                    .headers()
396                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
397                    .is_some()
398            {
399                did_retry = true;
400                token = llm_api_token.refresh(&client).await?;
401            } else {
402                let mut body = String::new();
403                response.body_mut().read_to_string(&mut body).await?;
404                break Err(anyhow!(
405                    "cloud language model completion failed with status {}: {body}",
406                    response.status()
407                ))?;
408            }
409        };
410
411        Ok(response)
412    }
413}
414
// `LanguageModel` implementation for zed.dev models. Each `CloudModel` family builds its
// own request type. Completions and tool calls are POSTed through
// `CloudLanguageModel::perform_llm_completion`, wrapped by `self.request_limiter`, and
// parsed from the newline-delimited JSON produced by `response_lines`.
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Telemetry id of the form `zed.dev/<model id>`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Prompt-cache settings; only Anthropic models can report any.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    /// Counts the tokens of `request` for this model.
    ///
    /// Anthropic and OpenAI use the shared `count_anthropic_tokens` /
    /// `count_open_ai_tokens` helpers. Google is counted remotely through the
    /// `proto::CountLanguageModelTokens` RPC. Zed models reuse the GPT-3.5 Turbo counter.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                // Only the message contents are needed for counting.
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                // Presumably an approximation: Zed models are OpenAI-based (see the note in
                // `use_any_tool`), so the GPT-3.5 Turbo tokenizer stands in.
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    /// Streams completion text for `request`.
    ///
    /// The request is converted to the provider's format and sent through
    /// `perform_llm_completion`. Text is then extracted from the provider's event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    // The Anthropic parser works with `AnthropicError`, so line errors are
                    // wrapped here and converted back to `anyhow` below.
                    Ok(anthropic::extract_text_from_events(
                        response_lines(response).map_err(AnthropicError::Other),
                    ))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map_err(|err| anyhow!(err)))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                // Zed models use the OpenAI request format.
                let mut request = request.into_open_ai(model.id().into(), None);
                // Hard-coded output cap for Zed models. Note that `use_any_tool` does not
                // set it.
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    /// Forces the model to call the single tool `tool_name` and streams its arguments.
    ///
    /// The request declares exactly one tool (`tool_name`, `tool_description`,
    /// `input_schema`) and sets `tool_choice` to it. Google models are not supported and
    /// return an error immediately.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                // `tool_model_id()` rather than `id()`: presumably honors
                // `AvailableModel::tool_override` — confirm in the anthropic crate.
                let mut request =
                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                // `tool_choice` only names the function; the full definition is in `tools`.
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}
741
742fn response_lines<T: DeserializeOwned>(
743    response: Response<AsyncBody>,
744) -> impl Stream<Item = Result<T>> {
745    futures::stream::try_unfold(
746        (String::new(), BufReader::new(response.into_body())),
747        move |(mut line, mut body)| async {
748            match body.read_line(&mut line).await {
749                Ok(0) => Ok(None),
750                Ok(_) => {
751                    let event: T = serde_json::from_str(&line)?;
752                    line.clear();
753                    Ok(Some((event, (line, body))))
754                }
755                Err(e) => Err(e.into()),
756            }
757        },
758    )
759}
760
761impl LlmApiToken {
762    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
763        let lock = self.0.upgradable_read().await;
764        if let Some(token) = lock.as_ref() {
765            Ok(token.to_string())
766        } else {
767            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
768        }
769    }
770
771    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
772        Self::fetch(self.0.write().await, &client).await
773    }
774
775    async fn fetch<'a>(
776        mut lock: RwLockWriteGuard<'a, Option<String>>,
777        client: &Arc<Client>,
778    ) -> Result<String> {
779        let response = client.request(proto::GetLlmToken {}).await?;
780        *lock = Some(response.token.clone());
781        Ok(response.token.clone())
782    }
783}
784
785struct ConfigurationView {
786    state: gpui::Model<State>,
787}
788
789impl ConfigurationView {
790    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
791        self.state.update(cx, |state, cx| {
792            state.authenticate(cx).detach_and_log_err(cx);
793        });
794        cx.notify();
795    }
796
797    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
798        if self.state.read(cx).has_accepted_terms_of_service(cx) {
799            return None;
800        }
801
802        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();
803
804        let terms_button = Button::new("terms_of_service", "Terms of Service")
805            .style(ButtonStyle::Subtle)
806            .icon(IconName::ExternalLink)
807            .icon_color(Color::Muted)
808            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));
809
810        let text =
811            "In order to use Zed AI, please read and accept our terms and conditions to continue:";
812
813        let form = v_flex()
814            .gap_2()
815            .child(Label::new("Terms and Conditions"))
816            .child(Label::new(text))
817            .child(h_flex().justify_center().child(terms_button))
818            .child(
819                h_flex().justify_center().child(
820                    Button::new("accept_terms", "I've read and accept the terms of service")
821                        .style(ButtonStyle::Tinted(TintColor::Accent))
822                        .disabled(accept_terms_disabled)
823                        .on_click({
824                            let state = self.state.downgrade();
825                            move |_, cx| {
826                                state
827                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
828                                    .ok();
829                            }
830                        }),
831                ),
832            );
833
834        Some(form.into_any())
835    }
836}
837
838impl Render for ConfigurationView {
839    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
840        const ZED_AI_URL: &str = "https://zed.dev/ai";
841        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
842
843        let is_connected = !self.state.read(cx).is_signed_out();
844        let plan = self.state.read(cx).user_store.read(cx).current_plan();
845        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);
846
847        let is_pro = plan == Some(proto::Plan::ZedPro);
848        let subscription_text = Label::new(if is_pro {
849            "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
850        } else {
851            "You have basic access to models from Anthropic through the Zed AI Free plan."
852        });
853        let manage_subscription_button = if is_pro {
854            Some(
855                h_flex().child(
856                    Button::new("manage_settings", "Manage Subscription")
857                        .style(ButtonStyle::Tinted(TintColor::Accent))
858                        .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
859                ),
860            )
861        } else if cx.has_flag::<ZedPro>() {
862            Some(
863                h_flex()
864                    .gap_2()
865                    .child(
866                        Button::new("learn_more", "Learn more")
867                            .style(ButtonStyle::Subtle)
868                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
869                    )
870                    .child(
871                        Button::new("upgrade", "Upgrade")
872                            .style(ButtonStyle::Subtle)
873                            .color(Color::Accent)
874                            .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
875                    ),
876            )
877        } else {
878            None
879        };
880
881        if is_connected {
882            v_flex()
883                .gap_3()
884                .max_w_4_5()
885                .children(self.render_accept_terms(cx))
886                .when(has_accepted_terms, |this| {
887                    this.child(subscription_text)
888                        .children(manage_subscription_button)
889                })
890        } else {
891            v_flex()
892                .gap_6()
893                .child(Label::new("Use the zed.dev to access language models."))
894                .child(
895                    v_flex()
896                        .gap_2()
897                        .child(
898                            Button::new("sign_in", "Sign in")
899                                .icon_color(Color::Muted)
900                                .icon(IconName::Github)
901                                .icon_position(IconPosition::Start)
902                                .style(ButtonStyle::Filled)
903                                .full_width()
904                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
905                        )
906                        .child(
907                            div().flex().w_full().items_center().child(
908                                Label::new("Sign in to enable collaboration.")
909                                    .color(Color::Muted)
910                                    .size(LabelSize::Small),
911                            ),
912                        ),
913                )
914        }
915    }
916}