1use super::open_ai::count_open_ai_tokens;
2use crate::provider::anthropic::map_to_language_model_completion_events;
3use crate::{
4 settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
5 LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
6 LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
7};
8use anthropic::AnthropicError;
9use anyhow::{anyhow, Result};
10use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
11use collections::BTreeMap;
12use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
13use futures::{
14 future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
15 TryStreamExt as _,
16};
17use gpui::{
18 AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
19 Subscription, Task,
20};
21use http_client::{AsyncBody, HttpClient, Method, Response};
22use schemars::JsonSchema;
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use serde_json::value::RawValue;
25use settings::{Settings, SettingsStore};
26use smol::{
27 io::{AsyncReadExt, BufReader},
28 lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
29};
30use std::{
31 future,
32 sync::{Arc, LazyLock},
33};
34use strum::IntoEnumIterator;
35use ui::{prelude::*, TintColor};
36
37use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
38
39use super::anthropic::count_anthropic_tokens;
40
41pub const PROVIDER_ID: &str = "zed.dev";
42pub const PROVIDER_NAME: &str = "Zed";
43
44const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
45 option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
46
47fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
48 static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
49 ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
50 .map(|json| serde_json::from_str(json).unwrap())
51 .unwrap_or_default()
52 });
53 ADDITIONAL_MODELS.as_slice()
54}
55
/// User-configurable settings for the `zed.dev` provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Additional models (beyond the built-in ones) declared in settings.
    pub available_models: Vec<AvailableModel>,
}
60
/// The upstream vendor that actually serves a custom cloud model.
/// Serialized in lowercase (e.g. `"anthropic"`) in settings JSON.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
68
/// A custom model entry as declared in settings (or the closed-beta JSON),
/// mapped onto a `CloudModel::*::Custom` variant in `provided_models`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}
86
/// Language-model provider backed by Zed's hosted LLM service (zed.dev).
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Bearer token for the LLM API; lazily fetched and refreshed when the
    /// server reports expiry. Shared with every model this provider creates.
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    /// Background task mirroring the client's connection status into `state`;
    /// held only so it is dropped (and thus cancelled) with the provider.
    _maintain_client_status: Task<()>,
}
93
/// Observable provider state: sign-in status and terms-of-service acceptance.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    /// Latest connection status, kept up to date by the provider's
    /// `_maintain_client_status` task.
    status: client::Status,
    /// In-flight terms-of-service acceptance; `Some` while the request is
    /// pending (used to disable the accept button in the UI).
    accept_terms: Option<Task<Result<()>>>,
    /// Subscription that re-notifies observers on global settings changes.
    _subscription: Subscription,
}
101
impl State {
    /// Whether the user is currently signed out of the Zed client.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Kicks off sign-in and connection, notifying observers on completion.
    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    /// True when the signed-in user has accepted the terms of service.
    /// Defaults to `false` when there is no current user.
    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Starts accepting the terms of service. While the request is pending,
    /// `accept_terms` is `Some`, which the UI uses to disable the button;
    /// it is cleared again when the request finishes.
    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            // The result is intentionally ignored: success or failure, we
            // clear the in-flight task and re-render below.
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}
135
impl CloudLanguageModelProvider {
    /// Creates the provider and spawns a background task that keeps the
    /// shared `State` in sync with the client's connection status.
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        // Seed with the current status so the initial state is correct even
        // before the watcher task observes its first change.
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            // Re-render whenever global settings change.
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        // Hold only a weak handle so this task doesn't keep `State` alive.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    // State was dropped; stop watching.
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}
175
176impl LanguageModelProviderState for CloudLanguageModelProvider {
177 type ObservableEntity = State;
178
179 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
180 Some(self.state.clone())
181 }
182}
183
184impl LanguageModelProvider for CloudLanguageModelProvider {
185 fn id(&self) -> LanguageModelProviderId {
186 LanguageModelProviderId(PROVIDER_ID.into())
187 }
188
189 fn name(&self) -> LanguageModelProviderName {
190 LanguageModelProviderName(PROVIDER_NAME.into())
191 }
192
193 fn icon(&self) -> IconName {
194 IconName::AiZed
195 }
196
197 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
198 let mut models = BTreeMap::default();
199
200 if cx.is_staff() {
201 for model in anthropic::Model::iter() {
202 if !matches!(model, anthropic::Model::Custom { .. }) {
203 models.insert(model.id().to_string(), CloudModel::Anthropic(model));
204 }
205 }
206 for model in open_ai::Model::iter() {
207 if !matches!(model, open_ai::Model::Custom { .. }) {
208 models.insert(model.id().to_string(), CloudModel::OpenAi(model));
209 }
210 }
211 for model in google_ai::Model::iter() {
212 if !matches!(model, google_ai::Model::Custom { .. }) {
213 models.insert(model.id().to_string(), CloudModel::Google(model));
214 }
215 }
216 for model in ZedModel::iter() {
217 models.insert(model.id().to_string(), CloudModel::Zed(model));
218 }
219 } else {
220 models.insert(
221 anthropic::Model::Claude3_5Sonnet.id().to_string(),
222 CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
223 );
224 }
225
226 let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
227 zed_cloud_provider_additional_models()
228 } else {
229 &[]
230 };
231
232 // Override with available models from settings
233 for model in AllLanguageModelSettings::get_global(cx)
234 .zed_dot_dev
235 .available_models
236 .iter()
237 .chain(llm_closed_beta_models)
238 .cloned()
239 {
240 let model = match model.provider {
241 AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
242 name: model.name.clone(),
243 display_name: model.display_name.clone(),
244 max_tokens: model.max_tokens,
245 tool_override: model.tool_override.clone(),
246 cache_configuration: model.cache_configuration.as_ref().map(|config| {
247 anthropic::AnthropicModelCacheConfiguration {
248 max_cache_anchors: config.max_cache_anchors,
249 should_speculate: config.should_speculate,
250 min_total_token: config.min_total_token,
251 }
252 }),
253 max_output_tokens: model.max_output_tokens,
254 }),
255 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
256 name: model.name.clone(),
257 max_tokens: model.max_tokens,
258 max_output_tokens: model.max_output_tokens,
259 }),
260 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
261 name: model.name.clone(),
262 max_tokens: model.max_tokens,
263 }),
264 };
265 models.insert(model.id().to_string(), model.clone());
266 }
267
268 models
269 .into_values()
270 .map(|model| {
271 Arc::new(CloudLanguageModel {
272 id: LanguageModelId::from(model.id().to_string()),
273 model,
274 llm_api_token: self.llm_api_token.clone(),
275 client: self.client.clone(),
276 request_limiter: RateLimiter::new(4),
277 }) as Arc<dyn LanguageModel>
278 })
279 .collect()
280 }
281
282 fn is_authenticated(&self, cx: &AppContext) -> bool {
283 !self.state.read(cx).is_signed_out()
284 }
285
286 fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
287 Task::ready(Ok(()))
288 }
289
290 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
291 cx.new_view(|_cx| ConfigurationView {
292 state: self.state.clone(),
293 })
294 .into()
295 }
296
297 fn must_accept_terms(&self, cx: &AppContext) -> bool {
298 !self.state.read(cx).has_accepted_terms_of_service(cx)
299 }
300
301 fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
302 let state = self.state.read(cx);
303
304 let terms = [(
305 "terms_of_service",
306 "Terms of Service",
307 "https://zed.dev/terms-of-service",
308 )]
309 .map(|(id, label, url)| {
310 Button::new(id, label)
311 .style(ButtonStyle::Subtle)
312 .icon(IconName::ExternalLink)
313 .icon_size(IconSize::XSmall)
314 .icon_color(Color::Muted)
315 .on_click(move |_, cx| cx.open_url(url))
316 });
317
318 if state.has_accepted_terms_of_service(cx) {
319 None
320 } else {
321 let disabled = state.accept_terms.is_some();
322 Some(
323 v_flex()
324 .gap_2()
325 .child(
326 v_flex()
327 .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
328 .child(
329 Label::new(
330 "Please read and accept our terms and conditions to continue.",
331 )
332 .size(LabelSize::Small),
333 ),
334 )
335 .child(v_flex().gap_1().children(terms))
336 .child(
337 h_flex().justify_end().child(
338 Button::new("accept_terms", "I've read it and accept it")
339 .disabled(disabled)
340 .on_click({
341 let state = self.state.downgrade();
342 move |_, cx| {
343 state
344 .update(cx, |state, cx| {
345 state.accept_terms_of_service(cx)
346 })
347 .ok();
348 }
349 }),
350 ),
351 )
352 .into_any(),
353 )
354 }
355 }
356
357 fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
358 Task::ready(Ok(()))
359 }
360}
361
/// A single model served through Zed's hosted LLM endpoint.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    /// Shared bearer token; all models from the same provider share one slot.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Caps concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
369
/// Shared, lazily populated bearer token for the LLM API.
/// Cloning is cheap; all clones refer to the same token slot.
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
372
373impl CloudLanguageModel {
374 async fn perform_llm_completion(
375 client: Arc<Client>,
376 llm_api_token: LlmApiToken,
377 body: PerformCompletionParams,
378 ) -> Result<Response<AsyncBody>> {
379 let http_client = &client.http_client();
380
381 let mut token = llm_api_token.acquire(&client).await?;
382 let mut did_retry = false;
383
384 let response = loop {
385 let request = http_client::Request::builder()
386 .method(Method::POST)
387 .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
388 .header("Content-Type", "application/json")
389 .header("Authorization", format!("Bearer {token}"))
390 .body(serde_json::to_string(&body)?.into())?;
391 let mut response = http_client.send(request).await?;
392 if response.status().is_success() {
393 break response;
394 } else if !did_retry
395 && response
396 .headers()
397 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
398 .is_some()
399 {
400 did_retry = true;
401 token = llm_api_token.refresh(&client).await?;
402 } else {
403 let mut body = String::new();
404 response.body_mut().read_to_string(&mut body).await?;
405 break Err(anyhow!(
406 "cloud language model completion failed with status {}: {body}",
407 response.status()
408 ))?;
409 }
410 };
411
412 Ok(response)
413 }
414}
415
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Prompt-caching configuration; only Anthropic models support it.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    /// Counts the tokens in `request` for this model.
    ///
    /// Anthropic and OpenAI counts are computed locally; Google counts go
    /// through the server's count-tokens RPC. Zed models are tokenized like
    /// OpenAI's gpt-3.5-turbo.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    /// Streams completion events for `request`.
    ///
    /// Every variant serializes the provider-specific request, POSTs it to
    /// zed.dev via `perform_llm_completion`, and parses the response as
    /// newline-delimited JSON events (see `response_lines`). Requests are
    /// throttled by `request_limiter`. Non-Anthropic variants surface plain
    /// text as `LanguageModelCompletionEvent::Text`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    // Anthropic events are mapped into the richer completion
                    // event type (text, tool use, etc.).
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                // Zed models use the OpenAI wire format with a fixed output cap.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

    /// Forces the model to call the single tool described by `tool_name` /
    /// `input_schema` and streams back the raw tool-argument JSON.
    ///
    /// Anthropic models may substitute a dedicated tool model via
    /// `tool_model_id()`. Not implemented for Google models.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request =
                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
                // Constrain the model to this one tool.
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                // `tool_choice` names the function; the full definition
                // (description + schema) goes in `tools`.
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}
754
755fn response_lines<T: DeserializeOwned>(
756 response: Response<AsyncBody>,
757) -> impl Stream<Item = Result<T>> {
758 futures::stream::try_unfold(
759 (String::new(), BufReader::new(response.into_body())),
760 move |(mut line, mut body)| async {
761 match body.read_line(&mut line).await {
762 Ok(0) => Ok(None),
763 Ok(_) => {
764 let event: T = serde_json::from_str(&line)?;
765 line.clear();
766 Ok(Some((event, (line, body))))
767 }
768 Err(e) => Err(e.into()),
769 }
770 },
771 )
772}
773
774impl LlmApiToken {
775 async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
776 let lock = self.0.upgradable_read().await;
777 if let Some(token) = lock.as_ref() {
778 Ok(token.to_string())
779 } else {
780 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
781 }
782 }
783
784 async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
785 Self::fetch(self.0.write().await, client).await
786 }
787
788 async fn fetch<'a>(
789 mut lock: RwLockWriteGuard<'a, Option<String>>,
790 client: &Arc<Client>,
791 ) -> Result<String> {
792 let response = client.request(proto::GetLlmToken {}).await?;
793 *lock = Some(response.token.clone());
794 Ok(response.token.clone())
795 }
796}
797
/// UI view for the provider's configuration pane (sign-in and terms flow).
struct ConfigurationView {
    state: gpui::Model<State>,
}
801
impl ConfigurationView {
    /// Starts the sign-in flow; errors are logged rather than surfaced here.
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }

    /// Renders the terms-of-service acceptance form, or `None` once the
    /// current user has already accepted.
    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
        if self.state.read(cx).has_accepted_terms_of_service(cx) {
            return None;
        }

        // Disable the accept button while an acceptance request is in flight.
        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();

        let terms_button = Button::new("terms_of_service", "Terms of Service")
            .style(ButtonStyle::Subtle)
            .icon(IconName::ExternalLink)
            .icon_color(Color::Muted)
            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));

        let text =
            "In order to use Zed AI, please read and accept our terms and conditions to continue:";

        let form = v_flex()
            .gap_2()
            .child(Label::new("Terms and Conditions"))
            .child(Label::new(text))
            .child(h_flex().justify_center().child(terms_button))
            .child(
                h_flex().justify_center().child(
                    Button::new("accept_terms", "I've read and accept the terms of service")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .disabled(accept_terms_disabled)
                        .on_click({
                            // Weak handle: the click handler must not keep
                            // the state model alive.
                            let state = self.state.downgrade();
                            move |_, cx| {
                                state
                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                    .ok();
                            }
                        }),
                ),
            );

        Some(form.into_any())
    }
}
850
impl Render for ConfigurationView {
    /// Renders either the signed-in pane (terms form and/or plan info) or a
    /// sign-in prompt, depending on connection state.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        // Pro users get a manage button; free users see upgrade options only
        // behind the ZedPro feature flag.
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .children(self.render_accept_terms(cx))
                // Plan details are only shown after the terms are accepted.
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use the zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}