1use anthropic::AnthropicModelMode;
2use anyhow::{Context as _, Result, anyhow};
3use chrono::{DateTime, Utc};
4use client::{Client, ModelRequestUsage, UserStore, zed_urls};
5use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag};
6use futures::{
7 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
8};
9use google_ai::GoogleModelMode;
10use gpui::{
11 AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
12};
13use http_client::http::{HeaderMap, HeaderValue};
14use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
15use language_model::{
16 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
17 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
18 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
19 LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
20 LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
21 ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
22};
23use proto::Plan;
24use release_channel::AppVersion;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize, de::DeserializeOwned};
27use settings::SettingsStore;
28use smol::io::{AsyncReadExt, BufReader};
29use std::pin::Pin;
30use std::str::FromStr as _;
31use std::sync::Arc;
32use std::time::Duration;
33use thiserror::Error;
34use ui::{TintColor, prelude::*};
35use util::{ResultExt as _, maybe};
36use zed_llm_client::{
37 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
38 CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
39 ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
40 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
41 TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
42};
43
44use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
45use crate::provider::google::{GoogleEventMapper, into_google};
46use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
47
/// Identifier of the Zed-hosted ("zed.dev") language model provider.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Display name of the Zed-hosted language model provider.
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
50
/// User-configurable settings for the `zed.dev` provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models the user makes available through this provider.
    pub available_models: Vec<AvailableModel>,
}
55
/// Upstream provider that actually serves a custom [`AvailableModel`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
63
/// A user-defined model entry, as it appears in settings for the
/// `zed.dev` provider.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
90
/// Operating mode for a model; serialized with a `type` tag
/// (e.g. `{"type": "thinking", "budget_tokens": 4096}`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
101
102impl From<ModelMode> for AnthropicModelMode {
103 fn from(value: ModelMode) -> Self {
104 match value {
105 ModelMode::Default => AnthropicModelMode::Default,
106 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
107 }
108 }
109}
110
/// Language model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    // Keeps alive the background task that mirrors client connection status
    // into `State`.
    _maintain_client_status: Task<()>,
}
116
/// Observable provider state: connection status, fetched model catalog, and
/// the background tasks/subscriptions that keep it up to date.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    // Latest connection status, mirrored from the client by the provider.
    status: client::Status,
    // In-flight terms-of-service acceptance, if any; cleared when it finishes.
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    // Full model list fetched from the server (including synthesized
    // "-thinking" variants).
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
131
impl State {
    /// Builds the provider state and kicks off a background task that waits
    /// for the client to connect, then fetches the model catalog once.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        // Feature flag is sampled once here and captured by the fetch task.
        let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Poll until the client reports a live connection; the
                    // models endpoint requires an authenticated session.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            // Resolve default/recommended entries by id against
                            // the freshly built list so they share the same Arcs.
                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            // Re-render whenever settings change.
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            // Refresh the LLM API token whenever the listener signals it.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// Whether the user is currently signed out of the Zed client.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Signs in (reconnecting if necessary) and notifies observers on success.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    /// Whether the current user has accepted the terms of service.
    /// Returns `false` when no user is signed in.
    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Starts accepting the terms of service; the stored task is cleared
    /// (and observers notified) once the request completes.
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }

    /// Fetches the list of available models from the LLM service's `/models`
    /// endpoint, authenticating with the current LLM API token.
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        use_cloud: bool,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(
                http_client
                    .build_zed_llm_url("/models", &[], use_cloud)?
                    .as_ref(),
            )
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Ok(serde_json::from_str(&body)?);
        } else {
            // Include the response body in the error to aid debugging.
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}
303
impl CloudLanguageModelProvider {
    /// Creates the provider, its observable [`State`], and a background task
    /// that keeps the state's connection status in sync with the client.
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        // Hold only a weak handle so this task doesn't keep the state alive;
        // the loop exits once the state entity is dropped.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    /// Wraps a server-described model in a [`CloudLanguageModel`] handle that
    /// shares this provider's client and LLM API token.
    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}
348
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's [`State`] entity so the UI can observe changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
356
impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    /// Server-designated default model, if the catalog has been fetched.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    /// Server-designated fast (cheaper/lower-latency) default model, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    /// Authenticated means signed in AND terms of service accepted.
    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
    }

    // Sign-in is driven elsewhere (via `State::authenticate`), so this is a no-op.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    /// Renders the terms-of-service prompt, or `None` once the user has
    /// already accepted.
    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        let state = self.state.read(cx);
        if state.has_accepted_terms_of_service(cx) {
            return None;
        }
        Some(
            render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
                let state = self.state.clone();
                move |_window, cx| {
                    state.update(cx, |state, cx| state.accept_terms_of_service(cx));
                }
            })
            .into_any_element(),
        )
    }

    // There are no locally stored credentials for this provider.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
446
/// Builds the terms-of-service UI: a link to the terms plus an accept button,
/// laid out differently per `view_kind` (compact inline row for the thread
/// empty state, stacked column elsewhere).
fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    // Button is disabled while a previous acceptance request is in flight.
    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                // Alignment of the accept button depends on where the prompt
                // is shown.
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}
523
/// Handle to a single Zed-hosted model; implements [`LanguageModel`] by
/// proxying requests through the LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Limits the number of concurrent in-flight completion requests.
    request_limiter: RateLimiter,
}
531
/// Successful completion response plus metadata decoded from response headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    // Usage parsed from headers; `None` when the server streams status
    // messages instead (see `includes_status_messages`).
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    // Whether the server will interleave status messages in the event stream.
    includes_status_messages: bool,
}
538
impl CloudLanguageModel {
    /// Sends a completion request to the LLM service, transparently refreshing
    /// an expired LLM token once and retrying. Maps known failure headers to
    /// typed errors (plan limit reached, payment required); everything else
    /// becomes an [`ApiError`].
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
        use_cloud: bool,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        // Loops at most twice: once with the current token, and once more
        // after a single token refresh on an expired-token response.
        loop {
            let request_builder = http_client::Request::builder().method(Method::POST).uri(
                http_client
                    .build_zed_llm_url("/completions", &[], use_cloud)?
                    .as_ref(),
            );
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage arrives in
                // the stream rather than in headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // Const pattern: matches only when the header value equals
                // MODEL_REQUESTS_RESOURCE_HEADER_VALUE.
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Fallback: surface the raw status/body/headers for diagnosis.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
647
/// Raw, un-categorized failure from the LLM HTTP API; carries the status,
/// response body, and headers for downstream classification.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
655
656impl From<ApiError> for LanguageModelCompletionError {
657 fn from(error: ApiError) -> Self {
658 let retry_after = None;
659 LanguageModelCompletionError::from_http_status(
660 PROVIDER_NAME,
661 error.status,
662 error.body,
663 retry_after,
664 )
665 }
666}
667
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// Identifies which upstream vendor actually serves this hosted model.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    // All tool-choice variants are supported for all hosted models.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    /// Google accepts only a subset of JSON Schema for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    // Only meaningful when the model supports burn (max) mode.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    /// Prompt caching is only configured for Anthropic-served models.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    /// Counts request tokens. Anthropic/OpenAI are counted locally; Google is
    /// counted server-side via the LLM service's `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[], use_cloud)?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion through the LLM service, translating the request
    /// into the upstream provider's format and mapping its event stream back
    /// into [`LanguageModelCompletionEvent`]s. Usage-update and
    /// tool-use-limit events (derived from response headers) are appended
    /// after the provider stream ends.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let use_cloud = cx
            .update(|cx| cx.has_flag::<ZedCloudFeatureFlag>())
            .unwrap_or(false);
        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    // "-thinking" model ids are client-side synthesized
                    // variants; give them a fixed reasoning budget.
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                        use_cloud,
                    )
                    .await
                    // Preserve the typed completion error when the failure was
                    // an ApiError; pass anything else through unchanged.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                        use_cloud,
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                        use_cloud,
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
1003
/// One line of the server's completion stream: either a service status update
/// or a provider-specific event of type `T`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
1010
/// Flattens a stream of cloud events into completion events: status updates
/// pass through as `StatusUpdate`, stream errors become completion errors, and
/// provider events are expanded via `map_callback` (one input event may yield
/// several output events).
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}
1035
1036fn usage_updated_event<T>(
1037 usage: Option<ModelRequestUsage>,
1038) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1039 futures::stream::iter(usage.map(|usage| {
1040 Ok(CloudCompletionEvent::Status(
1041 CompletionRequestStatus::UsageUpdated {
1042 amount: usage.amount as usize,
1043 limit: usage.limit,
1044 },
1045 ))
1046 }))
1047}
1048
1049fn tool_use_limit_reached_event<T>(
1050 tool_use_limit_reached: bool,
1051) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1052 futures::stream::iter(tool_use_limit_reached.then(|| {
1053 Ok(CloudCompletionEvent::Status(
1054 CompletionRequestStatus::ToolUseLimitReached,
1055 ))
1056 }))
1057}
1058
/// Turns an HTTP response body into a stream of newline-delimited JSON
/// events.
///
/// When `includes_status_messages` is set, each line is parsed as a full
/// `CloudCompletionEvent<T>`; otherwise the line is parsed as a bare `T` and
/// wrapped in `CloudCompletionEvent::Event`. A parse or read error ends the
/// stream with that error.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        // State: a line buffer (reused across iterations to avoid
        // per-line allocations) and a buffered reader over the body.
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // 0 bytes read means EOF: terminate the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    // Clear (but keep) the buffer for the next iteration.
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
1083
/// Renderable snapshot of the Zed AI provider's configuration state,
/// including sign-in status, plan information, and the callbacks the UI
/// buttons invoke.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    // Whether the user is signed in to Zed.
    is_connected: bool,
    // The user's current plan, if any.
    plan: Option<proto::Plan>,
    // (start, end) of the active subscription period, if any.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    // Whether a free trial can still be offered to this user.
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
    // True while the terms-of-service acceptance request is in flight.
    accept_terms_of_service_in_progress: bool,
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1095
1096impl RenderOnce for ZedAiConfiguration {
1097 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1098 const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
1099
1100 let is_pro = self.plan == Some(proto::Plan::ZedPro);
1101 let subscription_text = match (self.plan, self.subscription_period) {
1102 (Some(proto::Plan::ZedPro), Some(_)) => {
1103 "You have access to Zed's hosted LLMs through your Zed Pro subscription."
1104 }
1105 (Some(proto::Plan::ZedProTrial), Some(_)) => {
1106 "You have access to Zed's hosted LLMs through your Zed Pro trial."
1107 }
1108 (Some(proto::Plan::Free), Some(_)) => {
1109 "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
1110 }
1111 _ => {
1112 if self.eligible_for_trial {
1113 "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
1114 } else {
1115 "Subscribe for access to Zed's hosted LLMs."
1116 }
1117 }
1118 };
1119 let manage_subscription_buttons = if is_pro {
1120 h_flex().child(
1121 Button::new("manage_settings", "Manage Subscription")
1122 .style(ButtonStyle::Tinted(TintColor::Accent))
1123 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1124 )
1125 } else {
1126 h_flex()
1127 .gap_2()
1128 .child(
1129 Button::new("learn_more", "Learn more")
1130 .style(ButtonStyle::Subtle)
1131 .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
1132 )
1133 .child(
1134 Button::new(
1135 "upgrade",
1136 if self.plan.is_none() && self.eligible_for_trial {
1137 "Start Trial"
1138 } else {
1139 "Upgrade"
1140 },
1141 )
1142 .style(ButtonStyle::Subtle)
1143 .color(Color::Accent)
1144 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1145 )
1146 };
1147
1148 if self.is_connected {
1149 v_flex()
1150 .gap_3()
1151 .w_full()
1152 .when(!self.has_accepted_terms_of_service, |this| {
1153 this.child(render_accept_terms(
1154 LanguageModelProviderTosView::Configuration,
1155 self.accept_terms_of_service_in_progress,
1156 {
1157 let callback = self.accept_terms_of_service_callback.clone();
1158 move |window, cx| (callback)(window, cx)
1159 },
1160 ))
1161 })
1162 .when(self.has_accepted_terms_of_service, |this| {
1163 this.child(subscription_text)
1164 .child(manage_subscription_buttons)
1165 })
1166 } else {
1167 v_flex()
1168 .gap_2()
1169 .child(Label::new("Use Zed AI to access hosted language models."))
1170 .child(
1171 Button::new("sign_in", "Sign In")
1172 .icon_color(Color::Muted)
1173 .icon(IconName::Github)
1174 .icon_position(IconPosition::Start)
1175 .on_click({
1176 let callback = self.sign_in_callback.clone();
1177 move |_, window, cx| (callback)(window, cx)
1178 }),
1179 )
1180 }
1181 }
1182}
1183
/// Stateful view wrapping [`ZedAiConfiguration`]; holds the provider state
/// entity and the pre-built callbacks shared with each rendered snapshot.
struct ConfigurationView {
    state: Entity<State>,
    // Triggers terms-of-service acceptance on the provider state.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    // Kicks off authentication on the provider state.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1189
1190impl ConfigurationView {
1191 fn new(state: Entity<State>) -> Self {
1192 let accept_terms_of_service_callback = Arc::new({
1193 let state = state.clone();
1194 move |_window: &mut Window, cx: &mut App| {
1195 state.update(cx, |state, cx| {
1196 state.accept_terms_of_service(cx);
1197 });
1198 }
1199 });
1200
1201 let sign_in_callback = Arc::new({
1202 let state = state.clone();
1203 move |_window: &mut Window, cx: &mut App| {
1204 state.update(cx, |state, cx| {
1205 state.authenticate(cx).detach_and_log_err(cx);
1206 });
1207 }
1208 });
1209
1210 Self {
1211 state,
1212 accept_terms_of_service_callback,
1213 sign_in_callback,
1214 }
1215 }
1216}
1217
1218impl Render for ConfigurationView {
1219 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1220 let state = self.state.read(cx);
1221 let user_store = state.user_store.read(cx);
1222
1223 ZedAiConfiguration {
1224 is_connected: !state.is_signed_out(),
1225 plan: user_store.current_plan(),
1226 subscription_period: user_store.subscription_period(),
1227 eligible_for_trial: user_store.trial_started_at().is_none(),
1228 has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
1229 accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
1230 accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
1231 sign_in_callback: self.sign_in_callback.clone(),
1232 }
1233 }
1234}
1235
1236impl Component for ZedAiConfiguration {
1237 fn scope() -> ComponentScope {
1238 ComponentScope::Agent
1239 }
1240
1241 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1242 fn configuration(
1243 is_connected: bool,
1244 plan: Option<proto::Plan>,
1245 eligible_for_trial: bool,
1246 has_accepted_terms_of_service: bool,
1247 ) -> AnyElement {
1248 ZedAiConfiguration {
1249 is_connected,
1250 plan,
1251 subscription_period: plan
1252 .is_some()
1253 .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1254 eligible_for_trial,
1255 has_accepted_terms_of_service,
1256 accept_terms_of_service_in_progress: false,
1257 accept_terms_of_service_callback: Arc::new(|_, _| {}),
1258 sign_in_callback: Arc::new(|_, _| {}),
1259 }
1260 .into_any_element()
1261 }
1262
1263 Some(
1264 v_flex()
1265 .p_4()
1266 .gap_4()
1267 .children(vec![
1268 single_example("Not connected", configuration(false, None, false, true)),
1269 single_example(
1270 "Accept Terms of Service",
1271 configuration(true, None, true, false),
1272 ),
1273 single_example(
1274 "No Plan - Not eligible for trial",
1275 configuration(true, None, false, true),
1276 ),
1277 single_example(
1278 "No Plan - Eligible for trial",
1279 configuration(true, None, true, true),
1280 ),
1281 single_example(
1282 "Free Plan",
1283 configuration(true, Some(proto::Plan::Free), true, true),
1284 ),
1285 single_example(
1286 "Zed Pro Trial Plan",
1287 configuration(true, Some(proto::Plan::ZedProTrial), true, true),
1288 ),
1289 single_example(
1290 "Zed Pro Plan",
1291 configuration(true, Some(proto::Plan::ZedPro), true, true),
1292 ),
1293 ])
1294 .into_any_element(),
1295 )
1296 }
1297}