use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
use gpui::{
    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
    Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::{AsyncReadExt, BufReader},
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};

use crate::{LanguageModelAvailability, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

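/// Additional models baked into the binary at compile time via the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable, surfaced
/// only to users in the LLM closed beta. Panics on first use if the JSON is
/// malformed.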
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| serde_json::from_str(json).unwrap())
            .unwrap_or(Vec::new())
    });
    ADDITIONAL_MODELS.as_slice()
}

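/// Settings for the `zed.dev` provider, including any user-defined models to
/// expose in addition to the built-in ones.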
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}

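/// The `zed.dev` language model provider: serves models hosted by Zed's LLM
/// service, authenticated through the user's Zed account.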
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

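/// Observable state shared between the provider and its configuration view:
/// the client's connection status and terms-of-service acceptance.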
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

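        // Mirror the client's connection status into `State`. The task holds
        // only a weak handle, so it stops as soon as the state is dropped.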
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

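        // Staff see every built-in model from each provider; everyone else
        // currently gets Claude 3.5 Sonnet, plus whatever settings add below.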
        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override the built-in list with models from the user's settings and,
        // for closed-beta users, the additional compiled-in models.
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    max_output_tokens: model.max_output_tokens,
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
        let state = self.state.read(cx);

        let terms = [(
            "terms_of_service",
            "Terms of Service",
            "https://zed.dev/terms-of-service",
        )]
        .map(|(id, label, url)| {
            Button::new(id, label)
                .style(ButtonStyle::Subtle)
                .icon(IconName::ExternalLink)
                .icon_size(IconSize::XSmall)
                .icon_color(Color::Muted)
                .on_click(move |_, cx| cx.open_url(url))
        });

        if state.has_accepted_terms_of_service(cx) {
            None
        } else {
            let disabled = state.accept_terms.is_some();
            Some(
                v_flex()
                    .gap_2()
                    .child(
                        v_flex()
                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
                            .child(
                                Label::new(
                                    "Please read and accept our terms and conditions to continue.",
                                )
                                .size(LabelSize::Small),
                            ),
                    )
                    .child(v_flex().gap_1().children(terms))
                    .child(
                        h_flex().justify_end().child(
                            Button::new("accept_terms", "I've read it and accept it")
                                .disabled(disabled)
                                .on_click({
                                    let state = self.state.downgrade();
                                    move |_, cx| {
                                        state
                                            .update(cx, |state, cx| {
                                                state.accept_terms_of_service(cx)
                                            })
                                            .ok();
                                    }
                                }),
                        ),
                    )
                    .into_any(),
            )
        }
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

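/// A lazily fetched bearer token for Zed's LLM service, shared by all models
/// from this provider and refreshed when the server reports it has expired.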
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

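        // Issue the request, retrying exactly once with a refreshed token if
        // the server signals that the current LLM token has expired.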
        let response = loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(
                    "cloud language model completion failed with status {}: {body}",
                    response.status()
                ));
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
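            // Zed-hosted models are OpenAI-based; approximate their token
            // count with the gpt-3.5-turbo tokenizer.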
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(anthropic::extract_text_from_events(
                        response_lines(response).map_err(AnthropicError::Other),
                    ))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map_err(|err| anyhow!(err)))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into(), None);
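                // Zed models speak the OpenAI-compatible protocol; cap output
                // at 4,000 tokens per completion.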
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request =
                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}

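/// Converts a streaming response body into a stream of deserialized values,
/// parsing one JSON object per line (newline-delimited JSON).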
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

impl LlmApiToken {
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

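/// The provider's configuration UI: prompts the user to sign in and accept
/// the terms of service, and shows plan details once connected.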
struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }

    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
        if self.state.read(cx).has_accepted_terms_of_service(cx) {
            return None;
        }

        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();

        let terms_button = Button::new("terms_of_service", "Terms of Service")
            .style(ButtonStyle::Subtle)
            .icon(IconName::ExternalLink)
            .icon_color(Color::Muted)
            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));

        let text =
            "In order to use Zed AI, please read and accept our terms and conditions to continue:";

        let form = v_flex()
            .gap_2()
            .child(Label::new("Terms and Conditions"))
            .child(Label::new(text))
            .child(h_flex().justify_center().child(terms_button))
            .child(
                h_flex().justify_center().child(
                    Button::new("accept_terms", "I've read and accept the terms of service")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .disabled(accept_terms_disabled)
                        .on_click({
                            let state = self.state.downgrade();
                            move |_, cx| {
                                state
                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                    .ok();
                            }
                        }),
                ),
            );

        Some(form.into_any())
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .children(self.render_accept_terms(cx))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}