use super::open_ai::count_open_ai_tokens;
use crate::provider::anthropic::map_to_language_model_completion_events;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBeta, ZedPro};
use futures::{
    future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
    TryStreamExt as _,
};
use gpui::{
    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
    Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response};
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::{AsyncReadExt, BufReader},
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::time::Duration;
use std::{
    future,
    sync::{Arc, LazyLock},
};
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};

use crate::{LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
    option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

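/// Additional models compiled in via the `ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON`
/// environment variable and exposed to users behind the LLM closed-beta flag.
///
/// The value is parsed as a `Vec<AvailableModel>`. A minimal sketch of the expected
/// shape, derived from the serde derive on `AvailableModel` below (the entry shown is
/// illustrative; optional fields may be omitted):
///
/// ```json
/// [{ "provider": "anthropic", "name": "claude-3-5-sonnet-20240620", "max_tokens": 200000 }]
/// ```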
fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
    static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
        ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
            .map(|json| {
                serde_json::from_str(json)
                    .expect("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON must be valid JSON")
            })
            .unwrap_or_default()
    });
    ADDITIONAL_MODELS.as_slice()
}

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
    pub low_speed_timeout: Option<Duration>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

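        // Keep `State::status` in sync with the client's connection status so
        // observers (e.g. the configuration view) re-render on sign-in/out.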
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

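        // Staff get every built-in model; everyone else currently gets Claude 3.5
        // Sonnet only. Models from settings (and, behind the closed-beta flag, the
        // compiled-in extras) are merged in below and can override these entries.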
        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBeta>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
        let state = self.state.read(cx);

        let terms = [(
            "terms_of_service",
            "Terms of Service",
            "https://zed.dev/terms-of-service",
        )]
        .map(|(id, label, url)| {
            Button::new(id, label)
                .style(ButtonStyle::Subtle)
                .icon(IconName::ExternalLink)
                .icon_size(IconSize::XSmall)
                .icon_color(Color::Muted)
                .on_click(move |_, cx| cx.open_url(url))
        });

        if state.has_accepted_terms_of_service(cx) {
            None
        } else {
            let disabled = state.accept_terms.is_some();
            Some(
                v_flex()
                    .gap_2()
                    .child(
                        v_flex()
                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
                            .child(
                                Label::new(
                                    "Please read and accept our terms and conditions to continue.",
                                )
                                .size(LabelSize::Small),
                            ),
                    )
                    .child(v_flex().gap_1().children(terms))
                    .child(
                        h_flex().justify_end().child(
                            Button::new("accept_terms", "I've read and accept the terms of service")
                                .disabled(disabled)
                                .on_click({
                                    let state = self.state.downgrade();
                                    move |_, cx| {
                                        state
                                            .update(cx, |state, cx| {
                                                state.accept_terms_of_service(cx)
                                            })
                                            .ok();
                                    }
                                }),
                        ),
                    )
                    .into_any(),
            )
        }
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

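/// A lazily fetched, process-shared LLM API token. `acquire` returns the cached
/// token (fetching one on first use); `refresh` replaces it after the server
/// reports expiry via `EXPIRED_LLM_TOKEN_HEADER_NAME`.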
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
        low_speed_timeout: Option<Duration>,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

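        // Send the request; if the server reports that our LLM token has expired,
        // refresh the token and retry exactly once before surfacing the error.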
        let response = loop {
            let mut request_builder = http_client::Request::builder();
            if let Some(low_speed_timeout) = low_speed_timeout {
                request_builder = request_builder.read_timeout(low_speed_timeout);
            }
            let request = request_builder
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                break Err(anyhow!(
                    "cloud language model completion failed with status {}: {body}",
                    response.status()
                ))?;
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn icon(&self) -> Option<IconName> {
        self.model.icon()
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) | CloudModel::Zed(_) => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
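                // Zed-hosted models are OpenAI-based at the time of writing, so
                // approximate the count with an OpenAI tokenizer.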
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let openai_low_speed_timeout =
            AllLanguageModelSettings::try_read_global(cx, |s| s.openai.low_speed_timeout)
                .flatten();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(
                    model.id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(map_to_language_model_completion_events(Box::pin(
                        response_lines(response).map_err(AnthropicError::Other),
                    )))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        openai_low_speed_timeout,
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(google_ai::extract_text_from_events(response_lines(
                        response,
                    )))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into(), None);
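                // Zed-hosted models currently apply a fixed 4000-token output cap.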
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                        None,
                    )
                    .await?;
                    Ok(open_ai::extract_text_from_events(response_lines(response)))
                });
                async move {
                    Ok(future
                        .await?
                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
                        .boxed())
                }
                .boxed()
            }
        }
    }

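    /// Forces the model to call `tool_name` and streams the tool's JSON arguments
    /// back as they arrive. Not implemented for Google models.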
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let client = self.client.clone();
        let llm_api_token = self.llm_api_token.clone();

        match &self.model {
            CloudModel::Anthropic(model) => {
                let mut request = request.into_anthropic(
                    model.tool_model_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                );
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(anthropic::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request =
                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into(), None);
                request.tool_choice = Some(open_ai::ToolChoice::Other(
                    open_ai::ToolDefinition::Function {
                        function: open_ai::FunctionDefinition {
                            name: tool_name.clone(),
                            description: None,
                            parameters: None,
                        },
                    },
                ));
                request.tools = vec![open_ai::ToolDefinition::Function {
                    function: open_ai::FunctionDefinition {
                        name: tool_name.clone(),
                        description: Some(tool_description),
                        parameters: Some(input_schema),
                    },
                }];

                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                            None,
                        )
                        .await?;

                        Ok(open_ai::extract_tool_args_from_events(
                            tool_name,
                            Box::pin(response_lines(response)),
                        )
                        .await?
                        .boxed())
                    })
                    .boxed()
            }
        }
    }
}

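/// Interprets the response body as newline-delimited JSON, deserializing each
/// line into a `T` and yielding the values as a stream.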
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event: T = serde_json::from_str(&line)?;
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

impl LlmApiToken {
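    /// Returns the cached token, fetching a fresh one on first use.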
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, client).await
        }
    }

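    /// Replaces any cached token with a freshly fetched one; called after the
    /// server reports that the current token has expired.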
    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token)
    }
}

struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }

    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
        if self.state.read(cx).has_accepted_terms_of_service(cx) {
            return None;
        }

        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();

        let terms_button = Button::new("terms_of_service", "Terms of Service")
            .style(ButtonStyle::Subtle)
            .icon(IconName::ExternalLink)
            .icon_color(Color::Muted)
            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));

        let text =
            "In order to use Zed AI, please read and accept our terms and conditions to continue:";

        let form = v_flex()
            .gap_2()
            .child(Label::new("Terms and Conditions"))
            .child(Label::new(text))
            .child(h_flex().justify_center().child(terms_button))
            .child(
                h_flex().justify_center().child(
                    Button::new("accept_terms", "I've read and accept the terms of service")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .disabled(accept_terms_disabled)
                        .on_click({
                            let state = self.state.downgrade();
                            move |_, cx| {
                                state
                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
                                    .ok();
                            }
                        }),
                ),
            );

        Some(form.into_any())
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                ),
            )
        } else if cx.has_flag::<ZedPro>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL))),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .children(self.render_accept_terms(cx))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}