1use super::open_ai::count_open_ai_tokens;
2use crate::{
3 settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
4 LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
5 LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
6};
7use anthropic::AnthropicError;
8use anyhow::{anyhow, Result};
9use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
10use collections::BTreeMap;
11use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
12use futures::{
13 future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
14 TryStreamExt as _,
15};
16use gpui::{
17 AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
18 Subscription, Task,
19};
20use http_client::{AsyncBody, HttpClient, Method, Response};
21use schemars::JsonSchema;
22use serde::{de::DeserializeOwned, Deserialize, Serialize};
23use serde_json::value::RawValue;
24use settings::{Settings, SettingsStore};
25use smol::{
26 io::{AsyncReadExt, BufReader},
27 lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
28};
29use std::{
30 future,
31 sync::{Arc, LazyLock},
32};
33use strum::IntoEnumIterator;
34use ui::prelude::*;
35
36use crate::{LanguageModelAvailability, LanguageModelProvider};
37
38use super::anthropic::count_anthropic_tokens;
39
40pub const PROVIDER_ID: &str = "zed.dev";
41pub const PROVIDER_NAME: &str = "Zed";
42
43const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
44 option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
45
46fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
47 static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
48 ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
49 .map(|json| serde_json::from_str(json).unwrap())
50 .unwrap_or(Vec::new())
51 });
52 ADDITIONAL_MODELS.as_slice()
53}
54
/// User-configurable settings for the zed.dev language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Custom models to expose in addition to the built-in model list;
    /// an entry with the same id as a built-in model replaces it.
    pub available_models: Vec<AvailableModel>,
}
59
/// The upstream AI provider that actually serves a model offered through
/// zed.dev.
///
/// Serialized in lowercase (e.g. `"anthropic"`, `"openai"`, `"google"`) in
/// settings and in the build-time additional-models JSON.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
67
/// A custom model definition, supplied via settings or baked in at build
/// time, describing a model the zed.dev provider should offer.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}
85
/// Language model provider backed by Zed's hosted LLM service (zed.dev).
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Cached bearer token for the LLM API, shared with every model handle.
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    /// Background task mirroring the client's connection status into
    /// `state`; cancelled when the provider is dropped.
    _maintain_client_status: Task<()>,
}
92
/// Observable provider state: connection status and terms-of-service
/// acceptance.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    /// In-flight terms-of-service acceptance task, if the user just clicked
    /// accept; `None` otherwise.
    accept_terms: Option<Task<Result<()>>>,
    /// Notifies observers whenever the global settings store changes.
    _subscription: Subscription,
}
100
impl State {
    /// Whether the client is currently signed out.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Kicks off sign-in and connection, notifying observers when finished.
    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    /// Whether the current user has accepted the terms of service
    /// (`false` when there is no current user).
    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Starts accepting the terms of service. `accept_terms` holds the
    /// in-flight task (used by the UI to disable the button) and is cleared
    /// — with a notify — once the request completes.
    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}
134
impl CloudLanguageModelProvider {
    /// Creates the provider and spawns a background task that keeps the
    /// provider's `State` in sync with the client's connection status.
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        // Seed with the current status before subscribing to updates.
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        // Hold only a weak handle so this task doesn't keep the state model
        // alive; the loop exits once the state has been dropped.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}
174
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state so the UI can observe changes.
    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
182
183impl LanguageModelProvider for CloudLanguageModelProvider {
184 fn id(&self) -> LanguageModelProviderId {
185 LanguageModelProviderId(PROVIDER_ID.into())
186 }
187
188 fn name(&self) -> LanguageModelProviderName {
189 LanguageModelProviderName(PROVIDER_NAME.into())
190 }
191
192 fn icon(&self) -> IconName {
193 IconName::AiZed
194 }
195
196 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
197 let mut models = BTreeMap::default();
198
199 if cx.is_staff() {
200 for model in anthropic::Model::iter() {
201 if !matches!(model, anthropic::Model::Custom { .. }) {
202 models.insert(model.id().to_string(), CloudModel::Anthropic(model));
203 }
204 }
205 for model in open_ai::Model::iter() {
206 if !matches!(model, open_ai::Model::Custom { .. }) {
207 models.insert(model.id().to_string(), CloudModel::OpenAi(model));
208 }
209 }
210 for model in google_ai::Model::iter() {
211 if !matches!(model, google_ai::Model::Custom { .. }) {
212 models.insert(model.id().to_string(), CloudModel::Google(model));
213 }
214 }
215 for model in ZedModel::iter() {
216 models.insert(model.id().to_string(), CloudModel::Zed(model));
217 }
218 } else {
219 models.insert(
220 anthropic::Model::Claude3_5Sonnet.id().to_string(),
221 CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
222 );
223 }
224
225 let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
226 zed_cloud_provider_additional_models()
227 } else {
228 &[]
229 };
230
231 // Override with available models from settings
232 for model in AllLanguageModelSettings::get_global(cx)
233 .zed_dot_dev
234 .available_models
235 .iter()
236 .chain(llm_closed_beta_models)
237 .cloned()
238 {
239 let model = match model.provider {
240 AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
241 name: model.name.clone(),
242 display_name: model.display_name.clone(),
243 max_tokens: model.max_tokens,
244 tool_override: model.tool_override.clone(),
245 cache_configuration: model.cache_configuration.as_ref().map(|config| {
246 anthropic::AnthropicModelCacheConfiguration {
247 max_cache_anchors: config.max_cache_anchors,
248 should_speculate: config.should_speculate,
249 min_total_token: config.min_total_token,
250 }
251 }),
252 max_output_tokens: model.max_output_tokens,
253 }),
254 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
255 name: model.name.clone(),
256 max_tokens: model.max_tokens,
257 }),
258 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
259 name: model.name.clone(),
260 max_tokens: model.max_tokens,
261 }),
262 };
263 models.insert(model.id().to_string(), model.clone());
264 }
265
266 models
267 .into_values()
268 .map(|model| {
269 Arc::new(CloudLanguageModel {
270 id: LanguageModelId::from(model.id().to_string()),
271 model,
272 llm_api_token: self.llm_api_token.clone(),
273 client: self.client.clone(),
274 request_limiter: RateLimiter::new(4),
275 }) as Arc<dyn LanguageModel>
276 })
277 .collect()
278 }
279
280 fn is_authenticated(&self, cx: &AppContext) -> bool {
281 !self.state.read(cx).is_signed_out()
282 }
283
284 fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
285 Task::ready(Ok(()))
286 }
287
288 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
289 cx.new_view(|_cx| ConfigurationView {
290 state: self.state.clone(),
291 })
292 .into()
293 }
294
295 fn must_accept_terms(&self, cx: &AppContext) -> bool {
296 !self.state.read(cx).has_accepted_terms_of_service(cx)
297 }
298
299 fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
300 let state = self.state.read(cx);
301
302 let terms = [(
303 "terms_of_service",
304 "Terms of Service",
305 "https://zed.dev/terms-of-service",
306 )]
307 .map(|(id, label, url)| {
308 Button::new(id, label)
309 .style(ButtonStyle::Subtle)
310 .icon(IconName::ExternalLink)
311 .icon_size(IconSize::XSmall)
312 .icon_color(Color::Muted)
313 .on_click(move |_, cx| cx.open_url(url))
314 });
315
316 if state.has_accepted_terms_of_service(cx) {
317 None
318 } else {
319 let disabled = state.accept_terms.is_some();
320 Some(
321 v_flex()
322 .gap_2()
323 .child(
324 v_flex()
325 .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
326 .child(
327 Label::new(
328 "Please read and accept our terms and conditions to continue.",
329 )
330 .size(LabelSize::Small),
331 ),
332 )
333 .child(v_flex().gap_1().children(terms))
334 .child(
335 h_flex().justify_end().child(
336 Button::new("accept_terms", "I've read it and accept it")
337 .disabled(disabled)
338 .on_click({
339 let state = self.state.downgrade();
340 move |_, cx| {
341 state
342 .update(cx, |state, cx| {
343 state.accept_terms_of_service(cx)
344 })
345 .ok();
346 }
347 }),
348 ),
349 )
350 .into_any(),
351 )
352 }
353 }
354
355 fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
356 Task::ready(Ok(()))
357 }
358}
359
/// A single model served through Zed's hosted LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    /// Shared, lazily-fetched bearer token for the LLM API.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Caps concurrent requests for this model (4 at a time).
    request_limiter: RateLimiter,
}
367
/// Shared, lazily-fetched bearer token for Zed's LLM API. `None` until the
/// first `acquire` call; refreshed when the server reports it expired.
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
370
371impl CloudLanguageModel {
372 async fn perform_llm_completion(
373 client: Arc<Client>,
374 llm_api_token: LlmApiToken,
375 body: PerformCompletionParams,
376 ) -> Result<Response<AsyncBody>> {
377 let http_client = &client.http_client();
378
379 let mut token = llm_api_token.acquire(&client).await?;
380 let mut did_retry = false;
381
382 let response = loop {
383 let request = http_client::Request::builder()
384 .method(Method::POST)
385 .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
386 .header("Content-Type", "application/json")
387 .header("Authorization", format!("Bearer {token}"))
388 .body(serde_json::to_string(&body)?.into())?;
389 let mut response = http_client.send(request).await?;
390 if response.status().is_success() {
391 break response;
392 } else if !did_retry
393 && response
394 .headers()
395 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
396 .is_some()
397 {
398 did_retry = true;
399 token = llm_api_token.refresh(&client).await?;
400 } else {
401 let mut body = String::new();
402 response.body_mut().read_to_string(&mut body).await?;
403 break Err(anyhow!(
404 "cloud language model completion failed with status {}: {body}",
405 response.status()
406 ))?;
407 }
408 };
409
410 Ok(response)
411 }
412}
413
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Counts the tokens in `request` using a per-provider strategy:
    /// Anthropic and OpenAI are counted locally; Google is counted via an
    /// RPC to the server; Zed models are counted with the OpenAI
    /// gpt-3.5-turbo tokenizer.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    /// Streams a completion through Zed's LLM service.
    ///
    /// Each arm converts `request` into the upstream provider's wire format,
    /// sends it (rate-limited) via `perform_llm_completion`, and extracts
    /// plain text from the provider's event stream. Zed models reuse the
    /// OpenAI format with output capped at 4000 tokens.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(anthropic::extract_text_from_events(
                        response_lines(response).map_err(AnthropicError::Other),
                    ))
                });
                async move {
                    // Map Anthropic-specific errors back into `anyhow`.
                    Ok(future
                        .await?
                        .map(|result| result.map_err(|err| anyhow!(err)))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                // Zed models speak the OpenAI wire format.
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    /// Forces the model to call the given tool and streams back the tool's
    /// argument JSON.
    ///
    /// Anthropic uses its tool-override model and `ToolChoice::Tool`;
    /// OpenAI and Zed declare a single function and force it via
    /// `ToolChoice::Other`. Not implemented for Google models.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request =
                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request = request.into_open_ai(model.id().into());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}
724
725fn response_lines<T: DeserializeOwned>(
726 response: Response<AsyncBody>,
727) -> impl Stream<Item = Result<T>> {
728 futures::stream::try_unfold(
729 (String::new(), BufReader::new(response.into_body())),
730 move |(mut line, mut body)| async {
731 match body.read_line(&mut line).await {
732 Ok(0) => Ok(None),
733 Ok(_) => {
734 let event: T = serde_json::from_str(&line)?;
735 line.clear();
736 Ok(Some((event, (line, body))))
737 }
738 Err(e) => Err(e.into()),
739 }
740 },
741 )
742}
743
744impl LlmApiToken {
745 async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
746 let lock = self.0.upgradable_read().await;
747 if let Some(token) = lock.as_ref() {
748 Ok(token.to_string())
749 } else {
750 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
751 }
752 }
753
754 async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
755 Self::fetch(self.0.write().await, &client).await
756 }
757
758 async fn fetch<'a>(
759 mut lock: RwLockWriteGuard<'a, Option<String>>,
760 client: &Arc<Client>,
761 ) -> Result<String> {
762 let response = client.request(proto::GetLlmToken {}).await?;
763 *lock = Some(response.token.clone());
764 Ok(response.token.clone())
765 }
766}
767
/// Configuration UI for the zed.dev provider, observing the shared provider
/// `State` for sign-in and plan information.
struct ConfigurationView {
    state: gpui::Model<State>,
}
771
impl ConfigurationView {
    /// Starts sign-in via the provider state (logging any error) and
    /// re-renders immediately.
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}
780
781impl Render for ConfigurationView {
782 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
783 const ZED_AI_URL: &str = "https://zed.dev/ai";
784 const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
785
786 let is_connected = !self.state.read(cx).is_signed_out();
787 let plan = self.state.read(cx).user_store.read(cx).current_plan();
788 let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);
789
790 let is_pro = plan == Some(proto::Plan::ZedPro);
791
792 if is_connected {
793 v_flex()
794 .gap_3()
795 .max_w_4_5()
796 .when(must_accept_terms, |this| {
797 this.child(Label::new(
798 "You must accept the terms of service to use this provider.",
799 ))
800 })
801 .child(Label::new(
802 if is_pro {
803 "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
804 } else {
805 "You have basic access to models from Anthropic through the Zed AI Free plan."
806 }))
807 .children(if is_pro {
808 Some(
809 h_flex().child(
810 Button::new("manage_settings", "Manage Subscription")
811 .style(ButtonStyle::Filled)
812 .on_click(
813 cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
814 ),
815 ),
816 )
817 } else if cx.has_flag::<ZedPro>() {
818 Some(
819 h_flex()
820 .gap_2()
821 .child(
822 Button::new("learn_more", "Learn more")
823 .style(ButtonStyle::Subtle)
824 .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
825 )
826 .child(
827 Button::new("upgrade", "Upgrade")
828 .style(ButtonStyle::Subtle)
829 .color(Color::Accent)
830 .on_click(
831 cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
832 ),
833 ),
834 )
835 } else {
836 None
837 })
838 } else {
839 v_flex()
840 .gap_6()
841 .child(Label::new("Use the zed.dev to access language models."))
842 .child(
843 v_flex()
844 .gap_2()
845 .child(
846 Button::new("sign_in", "Sign in")
847 .icon_color(Color::Muted)
848 .icon(IconName::Github)
849 .icon_position(IconPosition::Start)
850 .style(ButtonStyle::Filled)
851 .full_width()
852 .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
853 )
854 .child(
855 div().flex().w_full().items_center().child(
856 Label::new("Sign in to enable collaboration.")
857 .color(Color::Muted)
858 .size(LabelSize::Small),
859 ),
860 ),
861 )
862 }
863 }
864}