1use anthropic::AnthropicModelMode;
2use anyhow::{Context as _, Result, anyhow};
3use chrono::{DateTime, Utc};
4use client::{Client, ModelRequestUsage, UserStore, zed_urls};
5use futures::{
6 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
7};
8use google_ai::GoogleModelMode;
9use gpui::{
10 AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
11};
12use http_client::http::{HeaderMap, HeaderValue};
13use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
14use language_model::{
15 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
16 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
17 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
18 LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
19 LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
20 ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
21};
22use proto::Plan;
23use release_channel::AppVersion;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use settings::SettingsStore;
27use smol::io::{AsyncReadExt, BufReader};
28use std::pin::Pin;
29use std::str::FromStr as _;
30use std::sync::Arc;
31use std::time::Duration;
32use thiserror::Error;
33use ui::{TintColor, prelude::*};
34use util::{ResultExt as _, maybe};
35use zed_llm_client::{
36 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
37 CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
38 ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
39 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
40 TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
41};
42
43use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
44use crate::provider::google::{GoogleEventMapper, into_google};
45use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
46
/// Identifier for the Zed-hosted ("zed.dev") language model provider.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Human-readable name for the Zed-hosted language model provider.
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
49
/// User-configurable settings for the `zed.dev` language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Custom models the user has declared in their settings.
    pub available_models: Vec<AvailableModel>,
}
54
/// The upstream provider that serves a custom [`AvailableModel`].
///
/// Serialized in lowercase (`"anthropic"`, `"openai"`, `"google"`) in settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
62
/// A custom language model entry declared by the user in their settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
89
/// The reasoning mode a model runs in, as declared in settings.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard (non-reasoning) mode.
    #[default]
    Default,
    /// Extended "thinking" mode with an optional reasoning-token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
100
101impl From<ModelMode> for AnthropicModelMode {
102 fn from(value: ModelMode) -> Self {
103 match value {
104 ModelMode::Default => AnthropicModelMode::Default,
105 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
106 }
107 }
108}
109
/// Language model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    /// Background task that mirrors the client's connection status into `State`.
    _maintain_client_status: Task<()>,
}
115
/// Observable state for the cloud provider: connection status, terms-of-service
/// acceptance, and the model list fetched from the server.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    // Mirrors the client's current connection status (kept in sync by the
    // provider's `_maintain_client_status` task).
    status: client::Status,
    // In-flight ToS acceptance; `Some` while the request is pending so the UI
    // can disable the accept button.
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    // All models advertised by the server (including synthesized "-thinking"
    // variants).
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
130
impl State {
    /// Builds the provider state and spawns background tasks that:
    /// - fetch the model list once the client is connected, and
    /// - refresh the LLM API token whenever the server requests it.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Poll until the client reports a connected status; the
                    // models endpoint needs an authenticated connection.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            // Resolve the server's default/fast/recommended ids
                            // against the expanded model list built above.
                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// Whether the user is currently signed out of the Zed client.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Starts sign-in and connection via the client, notifying observers once
    /// it completes.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    /// Whether the current user has accepted the terms of service.
    /// Treats "no user loaded yet" as not accepted.
    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Accepts the terms of service via the user store. The in-flight task is
    /// stored so the UI can show progress; it clears itself when done.
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }

    /// Fetches the list of available models from the LLM service, using the
    /// LLM API token as the bearer credential.
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Ok(serde_json::from_str(&body)?);
        } else {
            // Include the response body in the error to aid debugging.
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}
296
297impl CloudLanguageModelProvider {
298 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
299 let mut status_rx = client.status();
300 let status = *status_rx.borrow();
301
302 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
303
304 let state_ref = state.downgrade();
305 let maintain_client_status = cx.spawn(async move |cx| {
306 while let Some(status) = status_rx.next().await {
307 if let Some(this) = state_ref.upgrade() {
308 _ = this.update(cx, |this, cx| {
309 if this.status != status {
310 this.status = status;
311 cx.notify();
312 }
313 });
314 } else {
315 break;
316 }
317 }
318 });
319
320 Self {
321 client,
322 state: state.clone(),
323 _maintain_client_status: maintain_client_status,
324 }
325 }
326
327 fn create_language_model(
328 &self,
329 model: Arc<zed_llm_client::LanguageModel>,
330 llm_api_token: LlmApiToken,
331 ) -> Arc<dyn LanguageModel> {
332 Arc::new(CloudLanguageModel {
333 id: LanguageModelId(SharedString::from(model.id.0.clone())),
334 model,
335 llm_api_token: llm_api_token.clone(),
336 client: self.client.clone(),
337 request_limiter: RateLimiter::new(4),
338 })
339 }
340}
341
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's [`State`] entity so the UI can observe changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
349
impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    /// The server-designated default model, if the model list has been fetched.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    /// The server-designated fast (cheaper/lower-latency) default model.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    /// Models the server recommends surfacing prominently in the UI.
    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    /// All models advertised by the server (including thinking variants).
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    /// Authenticated means signed in AND terms of service accepted.
    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
    }

    // Sign-in is driven through the client/UI elsewhere, so this is a no-op.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    /// Renders the terms-of-service prompt, or `None` once they are accepted.
    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        let state = self.state.read(cx);
        if state.has_accepted_terms_of_service(cx) {
            return None;
        }
        Some(
            render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
                let state = self.state.clone();
                move |_window, cx| {
                    state.update(cx, |state, cx| state.accept_terms_of_service(cx));
                }
            })
            .into_any_element(),
        )
    }

    // Credentials live with the Zed account, so there is nothing to reset here.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
439
/// Renders the "accept Terms of Service" UI, laid out differently depending on
/// where it appears (`view_kind`). The accept button is disabled while an
/// acceptance request is in flight, and `accept_terms_callback` runs on click.
fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    // Link that opens the ToS in the browser.
    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    // Accept button; styled compactly in the thread empty state, full-width
    // elsewhere.
    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        // Compact single-row layout.
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        // Stacked layout; button alignment varies by view kind.
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}
516
/// A language model served through Zed's hosted LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    /// Server-provided model metadata.
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    /// Rate limiter for requests to this model (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
524
/// Successful result of [`CloudLanguageModel::perform_llm_completion`].
struct PerformLlmCompletionResponse {
    /// The streaming HTTP response.
    response: Response<AsyncBody>,
    /// Usage parsed from response headers; `None` when the server streams
    /// status messages inline instead.
    usage: Option<ModelRequestUsage>,
    /// Whether the server flagged that the tool-use limit was reached.
    tool_use_limit_reached: bool,
    /// Whether the server will interleave status messages in the event stream.
    includes_status_messages: bool,
}
531
impl CloudLanguageModel {
    /// Sends a completion request to the LLM service.
    ///
    /// On an expired-token response, the LLM token is refreshed once and the
    /// request retried. Known limit responses are mapped to dedicated error
    /// types ([`ModelRequestLimitReachedError`], [`PaymentRequiredError`]);
    /// any other failure becomes an [`ApiError`] carrying status, body, and
    /// headers.
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            // Advertise the app version when known so the server can gate
            // behavior by client version.
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage arrives in the
                // stream rather than in headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // Retry exactly once with a freshly refreshed token if the server
            // says ours expired.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // NOTE: this is a const pattern — it only matches when the
                // header value equals `MODEL_REQUESTS_RESOURCE_HEADER_VALUE`.
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Fall through: surface the raw response as an ApiError.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
637
/// Error for a non-success response from the LLM service.
///
/// The response headers are retained so conversions (e.g. into
/// [`LanguageModelCompletionError`]) can inspect rate-limit metadata.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
645
646impl From<ApiError> for LanguageModelCompletionError {
647 fn from(error: ApiError) -> Self {
648 let retry_after = None;
649 LanguageModelCompletionError::from_http_status(
650 PROVIDER_NAME,
651 error.status,
652 error.body,
653 retry_after,
654 )
655 }
656}
657
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// The id of the provider that actually serves this model upstream of Zed.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    // All tool-choice variants are supported for every cloud model.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    // Google requires a restricted JSON-schema subset for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    // Prompt caching is only configured for Anthropic-served models.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    /// Counts tokens for a request. Anthropic and OpenAI are counted locally;
    /// Google requests go through the service's `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion through the LLM service, translating the request
    /// into the upstream provider's wire format and mapping its event stream
    /// back into [`LanguageModelCompletionEvent`]s. Usage updates and
    /// tool-use-limit notices derived from response headers are appended to
    /// the event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                // Thinking variants are synthesized client-side with an
                // "-thinking" id suffix (see `State::new`).
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
978
/// A single line in the service's streamed response: either a service status
/// update or a provider-specific completion event payload.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    /// Service-level status message (usage updates, limits, etc.).
    Status(CompletionRequestStatus),
    /// Provider-specific event, deserialized as `T`.
    Event(T),
}
985
/// Flattens a stream of raw [`CloudCompletionEvent`]s into language-model
/// completion events: status messages pass through as `StatusUpdate`s, stream
/// errors become completion errors, and provider events are expanded via
/// `map_callback` (which may yield zero or more events per input).
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}
1010
/// A zero-or-one-element stream emitting a `UsageUpdated` status event when
/// usage information was parsed from the response headers.
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}
1023
/// A zero-or-one-element stream emitting a `ToolUseLimitReached` status event
/// when the response headers flagged the limit.
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}
1033
/// Parses a newline-delimited JSON response body into a stream of
/// [`CloudCompletionEvent`]s.
///
/// When `includes_status_messages` is set, each line is a tagged
/// `CloudCompletionEvent`; otherwise each line is a bare provider event `T`
/// and gets wrapped in `CloudCompletionEvent::Event`.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        // The String buffer is reused across iterations to avoid a fresh
        // allocation per line.
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // EOF: end the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
1058
/// Declarative state for the Zed AI configuration panel. Rendered once per
/// frame via its `RenderOnce` impl; the `Render` impl of `ConfigurationView`
/// builds a fresh instance from live provider state on every render.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    // Whether the user is signed in (derived from `State::is_signed_out`).
    is_connected: bool,
    // The user's current plan, if any.
    plan: Option<proto::Plan>,
    // Start/end of the active subscription period, when one exists.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    // True when no trial has been started yet — see `ConfigurationView`'s
    // render, which sets this from `trial_started_at().is_none()`.
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
    // True while the ToS acceptance request is in flight; disables the form.
    accept_terms_of_service_in_progress: bool,
    // Invoked when the user accepts the terms of service.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    // Invoked when the user clicks "Sign In".
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1070
1071impl RenderOnce for ZedAiConfiguration {
1072 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1073 const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
1074
1075 let is_pro = self.plan == Some(proto::Plan::ZedPro);
1076 let subscription_text = match (self.plan, self.subscription_period) {
1077 (Some(proto::Plan::ZedPro), Some(_)) => {
1078 "You have access to Zed's hosted LLMs through your Zed Pro subscription."
1079 }
1080 (Some(proto::Plan::ZedProTrial), Some(_)) => {
1081 "You have access to Zed's hosted LLMs through your Zed Pro trial."
1082 }
1083 (Some(proto::Plan::Free), Some(_)) => {
1084 "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
1085 }
1086 _ => {
1087 if self.eligible_for_trial {
1088 "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
1089 } else {
1090 "Subscribe for access to Zed's hosted LLMs."
1091 }
1092 }
1093 };
1094 let manage_subscription_buttons = if is_pro {
1095 h_flex().child(
1096 Button::new("manage_settings", "Manage Subscription")
1097 .style(ButtonStyle::Tinted(TintColor::Accent))
1098 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1099 )
1100 } else {
1101 h_flex()
1102 .gap_2()
1103 .child(
1104 Button::new("learn_more", "Learn more")
1105 .style(ButtonStyle::Subtle)
1106 .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
1107 )
1108 .child(
1109 Button::new(
1110 "upgrade",
1111 if self.plan.is_none() && self.eligible_for_trial {
1112 "Start Trial"
1113 } else {
1114 "Upgrade"
1115 },
1116 )
1117 .style(ButtonStyle::Subtle)
1118 .color(Color::Accent)
1119 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1120 )
1121 };
1122
1123 if self.is_connected {
1124 v_flex()
1125 .gap_3()
1126 .w_full()
1127 .when(!self.has_accepted_terms_of_service, |this| {
1128 this.child(render_accept_terms(
1129 LanguageModelProviderTosView::Configuration,
1130 self.accept_terms_of_service_in_progress,
1131 {
1132 let callback = self.accept_terms_of_service_callback.clone();
1133 move |window, cx| (callback)(window, cx)
1134 },
1135 ))
1136 })
1137 .when(self.has_accepted_terms_of_service, |this| {
1138 this.child(subscription_text)
1139 .child(manage_subscription_buttons)
1140 })
1141 } else {
1142 v_flex()
1143 .gap_2()
1144 .child(Label::new("Use Zed AI to access hosted language models."))
1145 .child(
1146 Button::new("sign_in", "Sign In")
1147 .icon_color(Color::Muted)
1148 .icon(IconName::Github)
1149 .icon_position(IconPosition::Start)
1150 .on_click({
1151 let callback = self.sign_in_callback.clone();
1152 move |_, window, cx| (callback)(window, cx)
1153 }),
1154 )
1155 }
1156 }
1157}
1158
/// Entity-backed view that renders `ZedAiConfiguration` from the live
/// provider state on each frame. The callbacks are built once in `new` and
/// cloned into each rendered configuration.
struct ConfigurationView {
    // Shared provider state: sign-in status, plan, and ToS acceptance.
    state: Entity<State>,
    // Accepts the terms of service via the shared state entity.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    // Starts authentication via the shared state entity.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1164
1165impl ConfigurationView {
1166 fn new(state: Entity<State>) -> Self {
1167 let accept_terms_of_service_callback = Arc::new({
1168 let state = state.clone();
1169 move |_window: &mut Window, cx: &mut App| {
1170 state.update(cx, |state, cx| {
1171 state.accept_terms_of_service(cx);
1172 });
1173 }
1174 });
1175
1176 let sign_in_callback = Arc::new({
1177 let state = state.clone();
1178 move |_window: &mut Window, cx: &mut App| {
1179 state.update(cx, |state, cx| {
1180 state.authenticate(cx).detach_and_log_err(cx);
1181 });
1182 }
1183 });
1184
1185 Self {
1186 state,
1187 accept_terms_of_service_callback,
1188 sign_in_callback,
1189 }
1190 }
1191}
1192
1193impl Render for ConfigurationView {
1194 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1195 let state = self.state.read(cx);
1196 let user_store = state.user_store.read(cx);
1197
1198 ZedAiConfiguration {
1199 is_connected: !state.is_signed_out(),
1200 plan: user_store.current_plan(),
1201 subscription_period: user_store.subscription_period(),
1202 eligible_for_trial: user_store.trial_started_at().is_none(),
1203 has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
1204 accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
1205 accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
1206 sign_in_callback: self.sign_in_callback.clone(),
1207 }
1208 }
1209}
1210
1211impl Component for ZedAiConfiguration {
1212 fn scope() -> ComponentScope {
1213 ComponentScope::Agent
1214 }
1215
1216 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1217 fn configuration(
1218 is_connected: bool,
1219 plan: Option<proto::Plan>,
1220 eligible_for_trial: bool,
1221 has_accepted_terms_of_service: bool,
1222 ) -> AnyElement {
1223 ZedAiConfiguration {
1224 is_connected,
1225 plan,
1226 subscription_period: plan
1227 .is_some()
1228 .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1229 eligible_for_trial,
1230 has_accepted_terms_of_service,
1231 accept_terms_of_service_in_progress: false,
1232 accept_terms_of_service_callback: Arc::new(|_, _| {}),
1233 sign_in_callback: Arc::new(|_, _| {}),
1234 }
1235 .into_any_element()
1236 }
1237
1238 Some(
1239 v_flex()
1240 .p_4()
1241 .gap_4()
1242 .children(vec![
1243 single_example("Not connected", configuration(false, None, false, true)),
1244 single_example(
1245 "Accept Terms of Service",
1246 configuration(true, None, true, false),
1247 ),
1248 single_example(
1249 "No Plan - Not eligible for trial",
1250 configuration(true, None, false, true),
1251 ),
1252 single_example(
1253 "No Plan - Eligible for trial",
1254 configuration(true, None, true, true),
1255 ),
1256 single_example(
1257 "Free Plan",
1258 configuration(true, Some(proto::Plan::Free), true, true),
1259 ),
1260 single_example(
1261 "Zed Pro Trial Plan",
1262 configuration(true, Some(proto::Plan::ZedProTrial), true, true),
1263 ),
1264 single_example(
1265 "Zed Pro Plan",
1266 configuration(true, Some(proto::Plan::ZedPro), true, true),
1267 ),
1268 ])
1269 .into_any_element(),
1270 )
1271 }
1272}