1use super::open_ai::count_open_ai_tokens;
2use crate::provider::anthropic::map_to_language_model_completion_events;
3use crate::{
4 settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
5 LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
6 LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
7};
8use anthropic::AnthropicError;
9use anyhow::{anyhow, Result};
10use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
11use collections::BTreeMap;
12use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
13use futures::{
14 future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
15 TryStreamExt as _,
16};
17use gpui::{
18 AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
19 Subscription, Task,
20};
21use http_client::{AsyncBody, HttpClient, Method, Response};
22use schemars::JsonSchema;
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use serde_json::value::RawValue;
25use settings::{Settings, SettingsStore};
26use smol::{
27 io::{AsyncReadExt, BufReader},
28 lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
29};
30use std::{
31 future,
32 sync::{Arc, LazyLock},
33};
34use strum::IntoEnumIterator;
35use ui::{prelude::*, TintColor};
36
37use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
38
39use super::anthropic::count_anthropic_tokens;
40
/// Identifier of the zed.dev provider; returned by both
/// `CloudLanguageModelProvider::id` and `CloudLanguageModel::provider_id`.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable name of the zed.dev provider; returned by both
/// `CloudLanguageModelProvider::name` and `CloudLanguageModel::provider_name`.
pub const PROVIDER_NAME: &str = "Zed";

/// Optional JSON array of extra [`AvailableModel`]s, captured at compile time from the
/// `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` environment variable (`None` when the
/// variable was not set during the build).
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
46
47fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
48 static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
49 ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
50 .map(|json| serde_json::from_str(json).unwrap())
51 .unwrap_or_default()
52 });
53 ADDITIONAL_MODELS.as_slice()
54}
55
/// User settings for the zed.dev provider.
///
/// NOTE(review): presumably this is the type of `AllLanguageModelSettings::zed_dot_dev`,
/// whose `available_models` is read in `provided_models` — confirm in the settings module.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Additional models to expose through the provider; each entry becomes a custom
    /// model of the corresponding upstream provider.
    pub available_models: Vec<AvailableModel>,
}
60
/// Upstream provider that backs a custom model served through zed.dev.
///
/// Serialized with lowercase variant names: `"anthropic"`, `"openai"`, `"google"`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
68
/// A custom model definition, supplied either through settings or through the
/// build-time `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON` list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    /// Only forwarded for Anthropic and OpenAI models; ignored for Google.
    pub max_output_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    /// Only used for Anthropic models.
    pub tool_override: Option<String>,
    /// Prompt-caching configuration for this custom model (`None` disables caching).
    /// Only used for Anthropic models.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}
86
/// Language model provider for models hosted behind zed.dev.
pub struct CloudLanguageModelProvider {
    /// Client used for authentication and for requests to zed.dev.
    client: Arc<Client>,
    /// LLM API token; every model created by `provided_models` receives a clone,
    /// and clones share the same underlying token slot.
    llm_api_token: LlmApiToken,
    /// Observable provider state (connection status, terms-of-service acceptance).
    state: gpui::Model<State>,
    /// Task that mirrors client status changes into `state`; presumably held here so
    /// the task stays alive as long as the provider does.
    _maintain_client_status: Task<()>,
}
93
/// Observable state of the zed.dev provider, shared with its configuration view.
pub struct State {
    /// Client used to (re)authenticate.
    client: Arc<Client>,
    /// User store consulted for terms-of-service acceptance and the current plan.
    user_store: Model<UserStore>,
    /// Latest client connection status, kept up to date by the provider's status task.
    status: client::Status,
    /// In-flight terms-of-service acceptance; `Some` while the request is pending,
    /// reset to `None` once it completes.
    accept_terms: Option<Task<Result<()>>>,
    /// Settings-store observation that notifies observers whenever settings change.
    _subscription: Subscription,
}
101
102impl State {
103 fn is_signed_out(&self) -> bool {
104 self.status.is_signed_out()
105 }
106
107 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
108 let client = self.client.clone();
109 cx.spawn(move |this, mut cx| async move {
110 client.authenticate_and_connect(true, &cx).await?;
111 this.update(&mut cx, |_, cx| cx.notify())
112 })
113 }
114
115 fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
116 self.user_store
117 .read(cx)
118 .current_user_has_accepted_terms()
119 .unwrap_or(false)
120 }
121
122 fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
123 let user_store = self.user_store.clone();
124 self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
125 let _ = user_store
126 .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
127 .await;
128 this.update(&mut cx, |this, cx| {
129 this.accept_terms = None;
130 cx.notify()
131 })
132 }));
133 }
134}
135
136impl CloudLanguageModelProvider {
137 pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
138 let mut status_rx = client.status();
139 let status = *status_rx.borrow();
140
141 let state = cx.new_model(|cx| State {
142 client: client.clone(),
143 user_store,
144 status,
145 accept_terms: None,
146 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
147 cx.notify();
148 }),
149 });
150
151 let state_ref = state.downgrade();
152 let maintain_client_status = cx.spawn(|mut cx| async move {
153 while let Some(status) = status_rx.next().await {
154 if let Some(this) = state_ref.upgrade() {
155 _ = this.update(&mut cx, |this, cx| {
156 if this.status != status {
157 this.status = status;
158 cx.notify();
159 }
160 });
161 } else {
162 break;
163 }
164 }
165 });
166
167 Self {
168 client,
169 state,
170 llm_api_token: LlmApiToken::default(),
171 _maintain_client_status: maintain_client_status,
172 }
173 }
174}
175
176impl LanguageModelProviderState for CloudLanguageModelProvider {
177 type ObservableEntity = State;
178
179 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
180 Some(self.state.clone())
181 }
182}
183
184impl LanguageModelProvider for CloudLanguageModelProvider {
185 fn id(&self) -> LanguageModelProviderId {
186 LanguageModelProviderId(PROVIDER_ID.into())
187 }
188
189 fn name(&self) -> LanguageModelProviderName {
190 LanguageModelProviderName(PROVIDER_NAME.into())
191 }
192
193 fn icon(&self) -> IconName {
194 IconName::AiZed
195 }
196
197 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
198 let mut models = BTreeMap::default();
199
200 if cx.is_staff() {
201 for model in anthropic::Model::iter() {
202 if !matches!(model, anthropic::Model::Custom { .. }) {
203 models.insert(model.id().to_string(), CloudModel::Anthropic(model));
204 }
205 }
206 for model in open_ai::Model::iter() {
207 if !matches!(model, open_ai::Model::Custom { .. }) {
208 models.insert(model.id().to_string(), CloudModel::OpenAi(model));
209 }
210 }
211 for model in google_ai::Model::iter() {
212 if !matches!(model, google_ai::Model::Custom { .. }) {
213 models.insert(model.id().to_string(), CloudModel::Google(model));
214 }
215 }
216 for model in ZedModel::iter() {
217 models.insert(model.id().to_string(), CloudModel::Zed(model));
218 }
219 } else {
220 models.insert(
221 anthropic::Model::Claude3_5Sonnet.id().to_string(),
222 CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
223 );
224 }
225
226 let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
227 zed_cloud_provider_additional_models()
228 } else {
229 &[]
230 };
231
232 // Override with available models from settings
233 for model in AllLanguageModelSettings::get_global(cx)
234 .zed_dot_dev
235 .available_models
236 .iter()
237 .chain(llm_closed_beta_models)
238 .cloned()
239 {
240 let model = match model.provider {
241 AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
242 name: model.name.clone(),
243 display_name: model.display_name.clone(),
244 max_tokens: model.max_tokens,
245 tool_override: model.tool_override.clone(),
246 cache_configuration: model.cache_configuration.as_ref().map(|config| {
247 anthropic::AnthropicModelCacheConfiguration {
248 max_cache_anchors: config.max_cache_anchors,
249 should_speculate: config.should_speculate,
250 min_total_token: config.min_total_token,
251 }
252 }),
253 max_output_tokens: model.max_output_tokens,
254 }),
255 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
256 name: model.name.clone(),
257 display_name: model.display_name.clone(),
258 max_tokens: model.max_tokens,
259 max_output_tokens: model.max_output_tokens,
260 }),
261 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
262 name: model.name.clone(),
263 display_name: model.display_name.clone(),
264 max_tokens: model.max_tokens,
265 }),
266 };
267 models.insert(model.id().to_string(), model.clone());
268 }
269
270 models
271 .into_values()
272 .map(|model| {
273 Arc::new(CloudLanguageModel {
274 id: LanguageModelId::from(model.id().to_string()),
275 model,
276 llm_api_token: self.llm_api_token.clone(),
277 client: self.client.clone(),
278 request_limiter: RateLimiter::new(4),
279 }) as Arc<dyn LanguageModel>
280 })
281 .collect()
282 }
283
284 fn is_authenticated(&self, cx: &AppContext) -> bool {
285 !self.state.read(cx).is_signed_out()
286 }
287
288 fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
289 Task::ready(Ok(()))
290 }
291
292 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
293 cx.new_view(|_cx| ConfigurationView {
294 state: self.state.clone(),
295 })
296 .into()
297 }
298
299 fn must_accept_terms(&self, cx: &AppContext) -> bool {
300 !self.state.read(cx).has_accepted_terms_of_service(cx)
301 }
302
303 fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
304 let state = self.state.read(cx);
305
306 let terms = [(
307 "terms_of_service",
308 "Terms of Service",
309 "https://zed.dev/terms-of-service",
310 )]
311 .map(|(id, label, url)| {
312 Button::new(id, label)
313 .style(ButtonStyle::Subtle)
314 .icon(IconName::ExternalLink)
315 .icon_size(IconSize::XSmall)
316 .icon_color(Color::Muted)
317 .on_click(move |_, cx| cx.open_url(url))
318 });
319
320 if state.has_accepted_terms_of_service(cx) {
321 None
322 } else {
323 let disabled = state.accept_terms.is_some();
324 Some(
325 v_flex()
326 .gap_2()
327 .child(
328 v_flex()
329 .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
330 .child(
331 Label::new(
332 "Please read and accept our terms and conditions to continue.",
333 )
334 .size(LabelSize::Small),
335 ),
336 )
337 .child(v_flex().gap_1().children(terms))
338 .child(
339 h_flex().justify_end().child(
340 Button::new("accept_terms", "I've read it and accept it")
341 .disabled(disabled)
342 .on_click({
343 let state = self.state.downgrade();
344 move |_, cx| {
345 state
346 .update(cx, |state, cx| {
347 state.accept_terms_of_service(cx)
348 })
349 .ok();
350 }
351 }),
352 ),
353 )
354 .into_any(),
355 )
356 }
357 }
358
359 fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
360 Task::ready(Ok(()))
361 }
362}
363
/// A single zed.dev-hosted model, as handed out by `CloudLanguageModelProvider::provided_models`.
pub struct CloudLanguageModel {
    /// Id derived from the underlying model's id.
    id: LanguageModelId,
    /// The underlying model (Anthropic, OpenAI, Google or Zed).
    model: CloudModel,
    /// Token shared with the provider and all of its other models.
    llm_api_token: LlmApiToken,
    /// Client used for RPCs and HTTP requests to zed.dev.
    client: Arc<Client>,
    /// Wraps every completion and tool request; created with `RateLimiter::new(4)`
    /// (presumably a cap of 4 concurrent requests — confirm in `RateLimiter`).
    request_limiter: RateLimiter,
}
371
/// Lazily fetched LLM API token for zed.dev.
///
/// The slot sits behind an `Arc`, so every clone shares it: a token fetched or refreshed
/// through one clone is seen by all of them. `None` means no token has been fetched yet.
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
374
375impl CloudLanguageModel {
376 async fn perform_llm_completion(
377 client: Arc<Client>,
378 llm_api_token: LlmApiToken,
379 body: PerformCompletionParams,
380 ) -> Result<Response<AsyncBody>> {
381 let http_client = &client.http_client();
382
383 let mut token = llm_api_token.acquire(&client).await?;
384 let mut did_retry = false;
385
386 let response = loop {
387 let request = http_client::Request::builder()
388 .method(Method::POST)
389 .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
390 .header("Content-Type", "application/json")
391 .header("Authorization", format!("Bearer {token}"))
392 .body(serde_json::to_string(&body)?.into())?;
393 let mut response = http_client.send(request).await?;
394 if response.status().is_success() {
395 break response;
396 } else if !did_retry
397 && response
398 .headers()
399 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
400 .is_some()
401 {
402 did_retry = true;
403 token = llm_api_token.refresh(&client).await?;
404 } else {
405 let mut body = String::new();
406 response.body_mut().read_to_string(&mut body).await?;
407 break Err(anyhow!(
408 "cloud language model completion failed with status {}: {body}",
409 response.status()
410 ))?;
411 }
412 };
413
414 Ok(response)
415 }
416}
417
/// [`LanguageModel`] implementation that routes all traffic through zed.dev, converting each
/// request into the wire format of the underlying model family.
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    /// Uses the underlying model's display name.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Telemetry id of the form `zed.dev/<model id>`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Only Anthropic models can report a cache configuration (converted from the
    /// Anthropic crate's type); every other model family returns `None`.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    /// Counts the tokens in `request`.
    ///
    /// Anthropic and OpenAI models delegate to `count_anthropic_tokens` /
    /// `count_open_ai_tokens`. Google models ask the server through the
    /// `CountLanguageModelTokens` RPC. Zed models are counted as if they were
    /// `open_ai::Model::ThreePointFiveTurbo`.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                // Only the message contents are needed for counting.
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    /// Streams a completion through the zed.dev `/completion` endpoint, running each
    /// request through `request_limiter`.
    ///
    /// Anthropic responses are mapped with `map_to_language_model_completion_events`.
    /// OpenAI, Google and Zed responses are reduced to text and emitted as
    /// `LanguageModelCompletionEvent::Text`. Zed models are sent as OpenAI-format requests
    /// with `max_tokens` fixed at 4000.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    // The Anthropic event mapper expects `AnthropicError`, so wrap line errors.
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                // Zed models use the OpenAI request format with a fixed output limit.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

    /// Forces the model to call a single tool named `tool_name` and streams the text
    /// chunks of that tool call's arguments.
    ///
    /// Anthropic requests use `model.tool_model_id()` as the model id. OpenAI and Zed
    /// requests use OpenAI-style function definitions. Google models return an error
    /// because tool use is not implemented for them.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request =
                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
                // Pin the model to this one tool.
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                // `tool_choice` only needs the function name; the full definition goes in `tools`.
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}
756
757fn response_lines<T: DeserializeOwned>(
758 response: Response<AsyncBody>,
759) -> impl Stream<Item = Result<T>> {
760 futures::stream::try_unfold(
761 (String::new(), BufReader::new(response.into_body())),
762 move |(mut line, mut body)| async {
763 match body.read_line(&mut line).await {
764 Ok(0) => Ok(None),
765 Ok(_) => {
766 let event: T = serde_json::from_str(&line)?;
767 line.clear();
768 Ok(Some((event, (line, body))))
769 }
770 Err(e) => Err(e.into()),
771 }
772 },
773 )
774}
775
776impl LlmApiToken {
777 async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
778 let lock = self.0.upgradable_read().await;
779 if let Some(token) = lock.as_ref() {
780 Ok(token.to_string())
781 } else {
782 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
783 }
784 }
785
786 async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
787 Self::fetch(self.0.write().await, client).await
788 }
789
790 async fn fetch<'a>(
791 mut lock: RwLockWriteGuard<'a, Option<String>>,
792 client: &Arc<Client>,
793 ) -> Result<String> {
794 let response = client.request(proto::GetLlmToken {}).await?;
795 *lock = Some(response.token.clone());
796 Ok(response.token.clone())
797 }
798}
799
/// Settings panel for the zed.dev provider: sign-in prompt, terms of service, and plan info.
struct ConfigurationView {
    /// Provider state shared with `CloudLanguageModelProvider`.
    state: gpui::Model<State>,
}
803
804impl ConfigurationView {
805 fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
806 self.state.update(cx, |state, cx| {
807 state.authenticate(cx).detach_and_log_err(cx);
808 });
809 cx.notify();
810 }
811
812 fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
813 if self.state.read(cx).has_accepted_terms_of_service(cx) {
814 return None;
815 }
816
817 let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();
818
819 let terms_button = Button::new("terms_of_service", "Terms of Service")
820 .style(ButtonStyle::Subtle)
821 .icon(IconName::ExternalLink)
822 .icon_color(Color::Muted)
823 .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));
824
825 let text =
826 "In order to use Zed AI, please read and accept our terms and conditions to continue:";
827
828 let form = v_flex()
829 .gap_2()
830 .child(Label::new("Terms and Conditions"))
831 .child(Label::new(text))
832 .child(h_flex().justify_center().child(terms_button))
833 .child(
834 h_flex().justify_center().child(
835 Button::new("accept_terms", "I've read and accept the terms of service")
836 .style(ButtonStyle::Tinted(TintColor::Accent))
837 .disabled(accept_terms_disabled)
838 .on_click({
839 let state = self.state.downgrade();
840 move |_, cx| {
841 state
842 .update(cx, |state, cx| state.accept_terms_of_service(cx))
843 .ok();
844 }
845 }),
846 ),
847 );
848
849 Some(form.into_any())
850 }
851}
852
853impl Render for ConfigurationView {
854 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
855 const ZED_AI_URL: &str = "https://zed.dev/ai";
856 const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
857
858 let is_connected = !self.state.read(cx).is_signed_out();
859 let plan = self.state.read(cx).user_store.read(cx).current_plan();
860 let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);
861
862 let is_pro = plan == Some(proto::Plan::ZedPro);
863 let subscription_text = Label::new(if is_pro {
864 "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
865 } else {
866 "You have basic access to models from Anthropic through the Zed AI Free plan."
867 });
868 let manage_subscription_button = if is_pro {
869 Some(
870 h_flex().child(
871 Button::new("manage_settings", "Manage Subscription")
872 .style(ButtonStyle::Tinted(TintColor::Accent))
873 .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
874 ),
875 )
876 } else if cx.has_flag::<ZedPro>() {
877 Some(
878 h_flex()
879 .gap_2()
880 .child(
881 Button::new("learn_more", "Learn more")
882 .style(ButtonStyle::Subtle)
883 .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
884 )
885 .child(
886 Button::new("upgrade", "Upgrade")
887 .style(ButtonStyle::Subtle)
888 .color(Color::Accent)
889 .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
890 ),
891 )
892 } else {
893 None
894 };
895
896 if is_connected {
897 v_flex()
898 .gap_3()
899 .max_w_4_5()
900 .children(self.render_accept_terms(cx))
901 .when(has_accepted_terms, |this| {
902 this.child(subscription_text)
903 .children(manage_subscription_button)
904 })
905 } else {
906 v_flex()
907 .gap_6()
908 .child(Label::new("Use the zed.dev to access language models."))
909 .child(
910 v_flex()
911 .gap_2()
912 .child(
913 Button::new("sign_in", "Sign in")
914 .icon_color(Color::Muted)
915 .icon(IconName::Github)
916 .icon_position(IconPosition::Start)
917 .style(ButtonStyle::Filled)
918 .full_width()
919 .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
920 )
921 .child(
922 div().flex().w_full().items_center().child(
923 Label::new("Sign in to enable collaboration.")
924 .color(Color::Muted)
925 .size(LabelSize::Small),
926 ),
927 ),
928 )
929 }
930 }
931}