1use super::open_ai::count_open_ai_tokens;
2use crate::provider::anthropic::map_to_language_model_completion_events;
3use crate::{
4 settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
5 LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
6 LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
7};
8use anthropic::AnthropicError;
9use anyhow::{anyhow, Result};
10use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
11use collections::BTreeMap;
12use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
13use futures::{
14 future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
15 TryStreamExt as _,
16};
17use gpui::{
18 AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
19 Subscription, Task,
20};
21use http_client::{AsyncBody, HttpClient, Method, Response};
22use schemars::JsonSchema;
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use serde_json::value::RawValue;
25use settings::{Settings, SettingsStore};
26use smol::{
27 io::{AsyncReadExt, BufReader},
28 lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
29};
30use std::{
31 future,
32 sync::{Arc, LazyLock},
33};
34use strum::IntoEnumIterator;
35use ui::{prelude::*, TintColor};
36
37use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};
38
39use super::anthropic::count_anthropic_tokens;
40
/// Identifier of the zed.dev language model provider, returned by `id()`.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable provider name, returned by `name()`.
pub const PROVIDER_NAME: &str = "Zed";

/// JSON list of extra models captured from the build environment at compile
/// time via `option_env!`; `None` when the variable was not set during the build.
const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
46
47fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
48 static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
49 ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
50 .map(|json| serde_json::from_str(json).unwrap())
51 .unwrap_or_default()
52 });
53 ADDITIONAL_MODELS.as_slice()
54}
55
/// Settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models to offer. `CloudLanguageModelProvider::provided_models`
    /// inserts them after the built-in models, replacing any with the same id.
    pub available_models: Vec<AvailableModel>,
}
60
/// Upstream provider behind a user-configured zed.dev model.
///
/// Serialized in lowercase (`"anthropic"`, `"openai"`, `"google"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
68
/// A model declared in settings (or in the build-time additional-models JSON)
/// to be served through zed.dev.
///
/// Not every field applies to every provider. `tool_override` and
/// `cache_configuration` are only used for Anthropic models.
/// `max_completion_tokens` is only used for OpenAI models. `max_output_tokens`
/// is used for Anthropic and OpenAI models. Google models use only `name`,
/// `display_name` and `max_tokens`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}
88
/// Language model provider backed by Zed's hosted LLM service (zed.dev).
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    /// Token cache handed (as a clone sharing the same `Arc`) to every model
    /// this provider returns.
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    /// Background task, spawned in `new`, that mirrors the client's connection
    /// status into `state`.
    _maintain_client_status: Task<()>,
}
95
/// Observable state shared by the zed.dev provider and its configuration view.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    /// Latest connection status from `client.status()`, kept in sync by the task
    /// spawned in `CloudLanguageModelProvider::new`.
    status: client::Status,
    /// In-flight terms-of-service acceptance request. While `Some`, the accept
    /// button is rendered disabled; the task resets it to `None` when done.
    accept_terms: Option<Task<Result<()>>>,
    /// Observation of the global `SettingsStore`; each change notifies observers.
    _subscription: Subscription,
}
103
104impl State {
105 fn is_signed_out(&self) -> bool {
106 self.status.is_signed_out()
107 }
108
109 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
110 let client = self.client.clone();
111 cx.spawn(move |this, mut cx| async move {
112 client.authenticate_and_connect(true, &cx).await?;
113 this.update(&mut cx, |_, cx| cx.notify())
114 })
115 }
116
117 fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
118 self.user_store
119 .read(cx)
120 .current_user_has_accepted_terms()
121 .unwrap_or(false)
122 }
123
124 fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
125 let user_store = self.user_store.clone();
126 self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
127 let _ = user_store
128 .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
129 .await;
130 this.update(&mut cx, |this, cx| {
131 this.accept_terms = None;
132 cx.notify()
133 })
134 }));
135 }
136}
137
138impl CloudLanguageModelProvider {
139 pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
140 let mut status_rx = client.status();
141 let status = *status_rx.borrow();
142
143 let state = cx.new_model(|cx| State {
144 client: client.clone(),
145 user_store,
146 status,
147 accept_terms: None,
148 _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
149 cx.notify();
150 }),
151 });
152
153 let state_ref = state.downgrade();
154 let maintain_client_status = cx.spawn(|mut cx| async move {
155 while let Some(status) = status_rx.next().await {
156 if let Some(this) = state_ref.upgrade() {
157 _ = this.update(&mut cx, |this, cx| {
158 if this.status != status {
159 this.status = status;
160 cx.notify();
161 }
162 });
163 } else {
164 break;
165 }
166 }
167 });
168
169 Self {
170 client,
171 state,
172 llm_api_token: LlmApiToken::default(),
173 _maintain_client_status: maintain_client_status,
174 }
175 }
176}
177
178impl LanguageModelProviderState for CloudLanguageModelProvider {
179 type ObservableEntity = State;
180
181 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
182 Some(self.state.clone())
183 }
184}
185
186impl LanguageModelProvider for CloudLanguageModelProvider {
187 fn id(&self) -> LanguageModelProviderId {
188 LanguageModelProviderId(PROVIDER_ID.into())
189 }
190
191 fn name(&self) -> LanguageModelProviderName {
192 LanguageModelProviderName(PROVIDER_NAME.into())
193 }
194
195 fn icon(&self) -> IconName {
196 IconName::AiZed
197 }
198
199 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
200 let mut models = BTreeMap::default();
201
202 if cx.is_staff() {
203 for model in anthropic::Model::iter() {
204 if !matches!(model, anthropic::Model::Custom { .. }) {
205 models.insert(model.id().to_string(), CloudModel::Anthropic(model));
206 }
207 }
208 for model in open_ai::Model::iter() {
209 if !matches!(model, open_ai::Model::Custom { .. }) {
210 models.insert(model.id().to_string(), CloudModel::OpenAi(model));
211 }
212 }
213 for model in google_ai::Model::iter() {
214 if !matches!(model, google_ai::Model::Custom { .. }) {
215 models.insert(model.id().to_string(), CloudModel::Google(model));
216 }
217 }
218 for model in ZedModel::iter() {
219 models.insert(model.id().to_string(), CloudModel::Zed(model));
220 }
221 } else {
222 models.insert(
223 anthropic::Model::Claude3_5Sonnet.id().to_string(),
224 CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
225 );
226 }
227
228 let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
229 zed_cloud_provider_additional_models()
230 } else {
231 &[]
232 };
233
234 // Override with available models from settings
235 for model in AllLanguageModelSettings::get_global(cx)
236 .zed_dot_dev
237 .available_models
238 .iter()
239 .chain(llm_closed_beta_models)
240 .cloned()
241 {
242 let model = match model.provider {
243 AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
244 name: model.name.clone(),
245 display_name: model.display_name.clone(),
246 max_tokens: model.max_tokens,
247 tool_override: model.tool_override.clone(),
248 cache_configuration: model.cache_configuration.as_ref().map(|config| {
249 anthropic::AnthropicModelCacheConfiguration {
250 max_cache_anchors: config.max_cache_anchors,
251 should_speculate: config.should_speculate,
252 min_total_token: config.min_total_token,
253 }
254 }),
255 max_output_tokens: model.max_output_tokens,
256 }),
257 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
258 name: model.name.clone(),
259 display_name: model.display_name.clone(),
260 max_tokens: model.max_tokens,
261 max_output_tokens: model.max_output_tokens,
262 max_completion_tokens: model.max_completion_tokens,
263 }),
264 AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
265 name: model.name.clone(),
266 display_name: model.display_name.clone(),
267 max_tokens: model.max_tokens,
268 }),
269 };
270 models.insert(model.id().to_string(), model.clone());
271 }
272
273 models
274 .into_values()
275 .map(|model| {
276 Arc::new(CloudLanguageModel {
277 id: LanguageModelId::from(model.id().to_string()),
278 model,
279 llm_api_token: self.llm_api_token.clone(),
280 client: self.client.clone(),
281 request_limiter: RateLimiter::new(4),
282 }) as Arc<dyn LanguageModel>
283 })
284 .collect()
285 }
286
287 fn is_authenticated(&self, cx: &AppContext) -> bool {
288 !self.state.read(cx).is_signed_out()
289 }
290
291 fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
292 Task::ready(Ok(()))
293 }
294
295 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
296 cx.new_view(|_cx| ConfigurationView {
297 state: self.state.clone(),
298 })
299 .into()
300 }
301
302 fn must_accept_terms(&self, cx: &AppContext) -> bool {
303 !self.state.read(cx).has_accepted_terms_of_service(cx)
304 }
305
306 fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
307 let state = self.state.read(cx);
308
309 let terms = [(
310 "terms_of_service",
311 "Terms of Service",
312 "https://zed.dev/terms-of-service",
313 )]
314 .map(|(id, label, url)| {
315 Button::new(id, label)
316 .style(ButtonStyle::Subtle)
317 .icon(IconName::ExternalLink)
318 .icon_size(IconSize::XSmall)
319 .icon_color(Color::Muted)
320 .on_click(move |_, cx| cx.open_url(url))
321 });
322
323 if state.has_accepted_terms_of_service(cx) {
324 None
325 } else {
326 let disabled = state.accept_terms.is_some();
327 Some(
328 v_flex()
329 .gap_2()
330 .child(
331 v_flex()
332 .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
333 .child(
334 Label::new(
335 "Please read and accept our terms and conditions to continue.",
336 )
337 .size(LabelSize::Small),
338 ),
339 )
340 .child(v_flex().gap_1().children(terms))
341 .child(
342 h_flex().justify_end().child(
343 Button::new("accept_terms", "I've read it and accept it")
344 .disabled(disabled)
345 .on_click({
346 let state = self.state.downgrade();
347 move |_, cx| {
348 state
349 .update(cx, |state, cx| {
350 state.accept_terms_of_service(cx)
351 })
352 .ok();
353 }
354 }),
355 ),
356 )
357 .into_any(),
358 )
359 }
360 }
361
362 fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
363 Task::ready(Ok(()))
364 }
365}
366
/// A single model served through zed.dev.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    /// Clone of the provider's token cache; all clones share one slot, so a
    /// refresh by any model is visible to the others.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Gates requests made through this model. `provided_models` builds it with
    /// `RateLimiter::new(4)` (NOTE(review): presumably a cap on concurrent
    /// requests — confirm against `RateLimiter`).
    request_limiter: RateLimiter,
}
374
/// Cached zed.dev LLM API token. Clones share the same `Arc<RwLock<..>>` slot,
/// which stays `None` until `acquire` or `refresh` fetches a token.
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
377
378impl CloudLanguageModel {
379 async fn perform_llm_completion(
380 client: Arc<Client>,
381 llm_api_token: LlmApiToken,
382 body: PerformCompletionParams,
383 ) -> Result<Response<AsyncBody>> {
384 let http_client = &client.http_client();
385
386 let mut token = llm_api_token.acquire(&client).await?;
387 let mut did_retry = false;
388
389 let response = loop {
390 let request = http_client::Request::builder()
391 .method(Method::POST)
392 .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
393 .header("Content-Type", "application/json")
394 .header("Authorization", format!("Bearer {token}"))
395 .body(serde_json::to_string(&body)?.into())?;
396 let mut response = http_client.send(request).await?;
397 if response.status().is_success() {
398 break response;
399 } else if !did_retry
400 && response
401 .headers()
402 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
403 .is_some()
404 {
405 did_retry = true;
406 token = llm_api_token.refresh(&client).await?;
407 } else {
408 let mut body = String::new();
409 response.body_mut().read_to_string(&mut body).await?;
410 break Err(anyhow!(
411 "cloud language model completion failed with status {}: {body}",
412 response.status()
413 ))?;
414 }
415 };
416
417 Ok(response)
418 }
419}
420
421impl LanguageModel for CloudLanguageModel {
422 fn id(&self) -> LanguageModelId {
423 self.id.clone()
424 }
425
426 fn name(&self) -> LanguageModelName {
427 LanguageModelName::from(self.model.display_name().to_string())
428 }
429
430 fn icon(&self) -> Option<IconName> {
431 self.model.icon()
432 }
433
434 fn provider_id(&self) -> LanguageModelProviderId {
435 LanguageModelProviderId(PROVIDER_ID.into())
436 }
437
438 fn provider_name(&self) -> LanguageModelProviderName {
439 LanguageModelProviderName(PROVIDER_NAME.into())
440 }
441
442 fn telemetry_id(&self) -> String {
443 format!("zed.dev/{}", self.model.id())
444 }
445
446 fn availability(&self) -> LanguageModelAvailability {
447 self.model.availability()
448 }
449
450 fn max_token_count(&self) -> usize {
451 self.model.max_token_count()
452 }
453
454 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
455 match &self.model {
456 CloudModel::Anthropic(model) => {
457 model
458 .cache_configuration()
459 .map(|cache| LanguageModelCacheConfiguration {
460 max_cache_anchors: cache.max_cache_anchors,
461 should_speculate: cache.should_speculate,
462 min_total_token: cache.min_total_token,
463 })
464 }
465 CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
466 }
467 }
468
469 fn count_tokens(
470 &self,
471 request: LanguageModelRequest,
472 cx: &AppContext,
473 ) -> BoxFuture<'static, Result<usize>> {
474 match self.model.clone() {
475 CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
476 CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
477 CloudModel::Google(model) => {
478 let client = self.client.clone();
479 let request = request.into_google(model.id().into());
480 let request = google_ai::CountTokensRequest {
481 contents: request.contents,
482 };
483 async move {
484 let request = serde_json::to_string(&request)?;
485 let response = client
486 .request(proto::CountLanguageModelTokens {
487 provider: proto::LanguageModelProvider::Google as i32,
488 request,
489 })
490 .await?;
491 Ok(response.token_count as usize)
492 }
493 .boxed()
494 }
495 CloudModel::Zed(_) => {
496 count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
497 }
498 }
499 }
500
501 fn stream_completion(
502 &self,
503 request: LanguageModelRequest,
504 _cx: &AsyncAppContext,
505 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
506 match &self.model {
507 CloudModel::Anthropic(model) => {
508 let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
509 let client = self.client.clone();
510 let llm_api_token = self.llm_api_token.clone();
511 let future = self.request_limiter.stream(async move {
512 let response = Self::perform_llm_completion(
513 client.clone(),
514 llm_api_token,
515 PerformCompletionParams {
516 provider: client::LanguageModelProvider::Anthropic,
517 model: request.model.clone(),
518 provider_request: RawValue::from_string(serde_json::to_string(
519 &request,
520 )?)?,
521 },
522 )
523 .await?;
524 Ok(map_to_language_model_completion_events(Box::pin(
525 response_lines(response).map_err(AnthropicError::Other),
526 )))
527 });
528 async move { Ok(future.await?.boxed()) }.boxed()
529 }
530 CloudModel::OpenAi(model) => {
531 let client = self.client.clone();
532 let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
533 let llm_api_token = self.llm_api_token.clone();
534 let future = self.request_limiter.stream(async move {
535 let response = Self::perform_llm_completion(
536 client.clone(),
537 llm_api_token,
538 PerformCompletionParams {
539 provider: client::LanguageModelProvider::OpenAi,
540 model: request.model.clone(),
541 provider_request: RawValue::from_string(serde_json::to_string(
542 &request,
543 )?)?,
544 },
545 )
546 .await?;
547 Ok(open_ai::extract_text_from_events(response_lines(response)))
548 });
549 async move {
550 Ok(future
551 .await?
552 .map(|result| result.map(LanguageModelCompletionEvent::Text))
553 .boxed())
554 }
555 .boxed()
556 }
557 CloudModel::Google(model) => {
558 let client = self.client.clone();
559 let request = request.into_google(model.id().into());
560 let llm_api_token = self.llm_api_token.clone();
561 let future = self.request_limiter.stream(async move {
562 let response = Self::perform_llm_completion(
563 client.clone(),
564 llm_api_token,
565 PerformCompletionParams {
566 provider: client::LanguageModelProvider::Google,
567 model: request.model.clone(),
568 provider_request: RawValue::from_string(serde_json::to_string(
569 &request,
570 )?)?,
571 },
572 )
573 .await?;
574 Ok(google_ai::extract_text_from_events(response_lines(
575 response,
576 )))
577 });
578 async move {
579 Ok(future
580 .await?
581 .map(|result| result.map(LanguageModelCompletionEvent::Text))
582 .boxed())
583 }
584 .boxed()
585 }
586 CloudModel::Zed(model) => {
587 let client = self.client.clone();
588 let mut request = request.into_open_ai(model.id().into(), None);
589 request.max_tokens = Some(4000);
590 let llm_api_token = self.llm_api_token.clone();
591 let future = self.request_limiter.stream(async move {
592 let response = Self::perform_llm_completion(
593 client.clone(),
594 llm_api_token,
595 PerformCompletionParams {
596 provider: client::LanguageModelProvider::Zed,
597 model: request.model.clone(),
598 provider_request: RawValue::from_string(serde_json::to_string(
599 &request,
600 )?)?,
601 },
602 )
603 .await?;
604 Ok(open_ai::extract_text_from_events(response_lines(response)))
605 });
606 async move {
607 Ok(future
608 .await?
609 .map(|result| result.map(LanguageModelCompletionEvent::Text))
610 .boxed())
611 }
612 .boxed()
613 }
614 }
615 }
616
617 fn use_any_tool(
618 &self,
619 request: LanguageModelRequest,
620 tool_name: String,
621 tool_description: String,
622 input_schema: serde_json::Value,
623 _cx: &AsyncAppContext,
624 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
625 let client = self.client.clone();
626 let llm_api_token = self.llm_api_token.clone();
627
628 match &self.model {
629 CloudModel::Anthropic(model) => {
630 let mut request =
631 request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
632 request.tool_choice = Some(anthropic::ToolChoice::Tool {
633 name: tool_name.clone(),
634 });
635 request.tools = vec![anthropic::Tool {
636 name: tool_name.clone(),
637 description: tool_description,
638 input_schema,
639 }];
640
641 self.request_limiter
642 .run(async move {
643 let response = Self::perform_llm_completion(
644 client.clone(),
645 llm_api_token,
646 PerformCompletionParams {
647 provider: client::LanguageModelProvider::Anthropic,
648 model: request.model.clone(),
649 provider_request: RawValue::from_string(serde_json::to_string(
650 &request,
651 )?)?,
652 },
653 )
654 .await?;
655
656 Ok(anthropic::extract_tool_args_from_events(
657 tool_name,
658 Box::pin(response_lines(response)),
659 )
660 .await?
661 .boxed())
662 })
663 .boxed()
664 }
665 CloudModel::OpenAi(model) => {
666 let mut request =
667 request.into_open_ai(model.id().into(), model.max_output_tokens());
668 request.tool_choice = Some(open_ai::ToolChoice::Other(
669 open_ai::ToolDefinition::Function {
670 function: open_ai::FunctionDefinition {
671 name: tool_name.clone(),
672 description: None,
673 parameters: None,
674 },
675 },
676 ));
677 request.tools = vec![open_ai::ToolDefinition::Function {
678 function: open_ai::FunctionDefinition {
679 name: tool_name.clone(),
680 description: Some(tool_description),
681 parameters: Some(input_schema),
682 },
683 }];
684
685 self.request_limiter
686 .run(async move {
687 let response = Self::perform_llm_completion(
688 client.clone(),
689 llm_api_token,
690 PerformCompletionParams {
691 provider: client::LanguageModelProvider::OpenAi,
692 model: request.model.clone(),
693 provider_request: RawValue::from_string(serde_json::to_string(
694 &request,
695 )?)?,
696 },
697 )
698 .await?;
699
700 Ok(open_ai::extract_tool_args_from_events(
701 tool_name,
702 Box::pin(response_lines(response)),
703 )
704 .await?
705 .boxed())
706 })
707 .boxed()
708 }
709 CloudModel::Google(_) => {
710 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
711 }
712 CloudModel::Zed(model) => {
713 // All Zed models are OpenAI-based at the time of writing.
714 let mut request = request.into_open_ai(model.id().into(), None);
715 request.tool_choice = Some(open_ai::ToolChoice::Other(
716 open_ai::ToolDefinition::Function {
717 function: open_ai::FunctionDefinition {
718 name: tool_name.clone(),
719 description: None,
720 parameters: None,
721 },
722 },
723 ));
724 request.tools = vec![open_ai::ToolDefinition::Function {
725 function: open_ai::FunctionDefinition {
726 name: tool_name.clone(),
727 description: Some(tool_description),
728 parameters: Some(input_schema),
729 },
730 }];
731
732 self.request_limiter
733 .run(async move {
734 let response = Self::perform_llm_completion(
735 client.clone(),
736 llm_api_token,
737 PerformCompletionParams {
738 provider: client::LanguageModelProvider::Zed,
739 model: request.model.clone(),
740 provider_request: RawValue::from_string(serde_json::to_string(
741 &request,
742 )?)?,
743 },
744 )
745 .await?;
746
747 Ok(open_ai::extract_tool_args_from_events(
748 tool_name,
749 Box::pin(response_lines(response)),
750 )
751 .await?
752 .boxed())
753 })
754 .boxed()
755 }
756 }
757 }
758}
759
760fn response_lines<T: DeserializeOwned>(
761 response: Response<AsyncBody>,
762) -> impl Stream<Item = Result<T>> {
763 futures::stream::try_unfold(
764 (String::new(), BufReader::new(response.into_body())),
765 move |(mut line, mut body)| async {
766 match body.read_line(&mut line).await {
767 Ok(0) => Ok(None),
768 Ok(_) => {
769 let event: T = serde_json::from_str(&line)?;
770 line.clear();
771 Ok(Some((event, (line, body))))
772 }
773 Err(e) => Err(e.into()),
774 }
775 },
776 )
777}
778
779impl LlmApiToken {
780 async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
781 let lock = self.0.upgradable_read().await;
782 if let Some(token) = lock.as_ref() {
783 Ok(token.to_string())
784 } else {
785 Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
786 }
787 }
788
789 async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
790 Self::fetch(self.0.write().await, client).await
791 }
792
793 async fn fetch<'a>(
794 mut lock: RwLockWriteGuard<'a, Option<String>>,
795 client: &Arc<Client>,
796 ) -> Result<String> {
797 let response = client.request(proto::GetLlmToken {}).await?;
798 *lock = Some(response.token.clone());
799 Ok(response.token.clone())
800 }
801}
802
/// Settings panel view for the zed.dev provider, rendered from the shared
/// provider [`State`].
struct ConfigurationView {
    state: gpui::Model<State>,
}
806
807impl ConfigurationView {
808 fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
809 self.state.update(cx, |state, cx| {
810 state.authenticate(cx).detach_and_log_err(cx);
811 });
812 cx.notify();
813 }
814
815 fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
816 if self.state.read(cx).has_accepted_terms_of_service(cx) {
817 return None;
818 }
819
820 let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();
821
822 let terms_button = Button::new("terms_of_service", "Terms of Service")
823 .style(ButtonStyle::Subtle)
824 .icon(IconName::ExternalLink)
825 .icon_color(Color::Muted)
826 .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));
827
828 let text =
829 "In order to use Zed AI, please read and accept our terms and conditions to continue:";
830
831 let form = v_flex()
832 .gap_2()
833 .child(Label::new("Terms and Conditions"))
834 .child(Label::new(text))
835 .child(h_flex().justify_center().child(terms_button))
836 .child(
837 h_flex().justify_center().child(
838 Button::new("accept_terms", "I've read and accept the terms of service")
839 .style(ButtonStyle::Tinted(TintColor::Accent))
840 .disabled(accept_terms_disabled)
841 .on_click({
842 let state = self.state.downgrade();
843 move |_, cx| {
844 state
845 .update(cx, |state, cx| state.accept_terms_of_service(cx))
846 .ok();
847 }
848 }),
849 ),
850 );
851
852 Some(form.into_any())
853 }
854}
855
856impl Render for ConfigurationView {
857 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
858 const ZED_AI_URL: &str = "https://zed.dev/ai";
859 const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
860
861 let is_connected = !self.state.read(cx).is_signed_out();
862 let plan = self.state.read(cx).user_store.read(cx).current_plan();
863 let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);
864
865 let is_pro = plan == Some(proto::Plan::ZedPro);
866 let subscription_text = Label::new(if is_pro {
867 "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
868 } else {
869 "You have basic access to models from Anthropic through the Zed AI Free plan."
870 });
871 let manage_subscription_button = if is_pro {
872 Some(
873 h_flex().child(
874 Button::new("manage_settings", "Manage Subscription")
875 .style(ButtonStyle::Tinted(TintColor::Accent))
876 .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
877 ),
878 )
879 } else if cx.has_flag::<ZedPro>() {
880 Some(
881 h_flex()
882 .gap_2()
883 .child(
884 Button::new("learn_more", "Learn more")
885 .style(ButtonStyle::Subtle)
886 .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
887 )
888 .child(
889 Button::new("upgrade", "Upgrade")
890 .style(ButtonStyle::Subtle)
891 .color(Color::Accent)
892 .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
893 ),
894 )
895 } else {
896 None
897 };
898
899 if is_connected {
900 v_flex()
901 .gap_3()
902 .max_w_4_5()
903 .children(self.render_accept_terms(cx))
904 .when(has_accepted_terms, |this| {
905 this.child(subscription_text)
906 .children(manage_subscription_button)
907 })
908 } else {
909 v_flex()
910 .gap_6()
911 .child(Label::new("Use the zed.dev to access language models."))
912 .child(
913 v_flex()
914 .gap_2()
915 .child(
916 Button::new("sign_in", "Sign in")
917 .icon_color(Color::Muted)
918 .icon(IconName::Github)
919 .icon_position(IconPosition::Start)
920 .style(ButtonStyle::Filled)
921 .full_width()
922 .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
923 )
924 .child(
925 div().flex().w_full().items_center().child(
926 Label::new("Sign in to enable collaboration.")
927 .color(Color::Muted)
928 .size(LabelSize::Small),
929 ),
930 ),
931 )
932 }
933 }
934}