1use anthropic::AnthropicModelMode;
2use anyhow::{Context as _, Result, anyhow};
3use chrono::{DateTime, Utc};
4use client::{Client, ModelRequestUsage, UserStore, zed_urls};
5use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag};
6use futures::{
7 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
8};
9use google_ai::GoogleModelMode;
10use gpui::{
11 AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
12};
13use http_client::http::{HeaderMap, HeaderValue};
14use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
15use language_model::{
16 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
17 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
18 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
19 LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
20 LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
21 ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
22};
23use proto::Plan;
24use release_channel::AppVersion;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize, de::DeserializeOwned};
27use settings::SettingsStore;
28use smol::io::{AsyncReadExt, BufReader};
29use std::pin::Pin;
30use std::str::FromStr as _;
31use std::sync::Arc;
32use std::time::Duration;
33use thiserror::Error;
34use ui::{TintColor, prelude::*};
35use util::{ResultExt as _, maybe};
36use zed_llm_client::{
37 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
38 CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
39 ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
40 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
41 TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
42};
43
44use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
45use crate::provider::google::{GoogleEventMapper, into_google};
46use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
47
// Identifier and display name for the Zed-hosted provider; both are shared
// constants from the `language_model` crate.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
50
/// Settings for the `zed.dev` language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// User-configured models exposed through the Zed provider.
    pub available_models: Vec<AvailableModel>,
}
55
/// The upstream vendor backing a user-configured model.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
63
/// A user-configured model made available through the Zed provider.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
90
/// Reasoning mode for a configured model.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
101
102impl From<ModelMode> for AnthropicModelMode {
103 fn from(value: ModelMode) -> Self {
104 match value {
105 ModelMode::Default => AnthropicModelMode::Default,
106 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
107 }
108 }
109}
110
/// Language model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    // Keeps the status-watcher task (spawned in `new`) alive for the
    // provider's lifetime.
    _maintain_client_status: Task<()>,
}
116
/// Observable state shared by the provider and its configuration UI.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    // In-flight ToS acceptance, if any; presence doubles as a "busy" flag.
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    // Model list fetched from the server (with "-thinking" variants added).
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
131
132impl State {
    /// Creates the provider state and starts its background work: a task
    /// that fetches the model list once the client connects, a settings
    /// observer, and a subscription that refreshes the LLM API token.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Poll (100ms) until the client reports a connection
                    // before hitting the models endpoint.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            // Resolve the server-designated default, fast, and
                            // recommended models against the expanded list.
                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            // Re-notify observers whenever settings change.
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            // On refresh events, renew the LLM API token in the background.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }
231
    /// Whether the client is currently signed out.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }
235
    /// Signs in and connects the client (prompting the user if needed),
    /// then notifies observers of this state.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }
246
247 fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
248 self.user_store
249 .read(cx)
250 .current_user_has_accepted_terms()
251 .unwrap_or(false)
252 }
253
    /// Accepts the terms of service for the current user. The in-flight task
    /// is stored so the UI can show progress, and cleared when it finishes.
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            // The result is deliberately ignored; this task only tracks
            // completion so the UI can re-render.
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }
266
267 async fn fetch_models(
268 client: Arc<Client>,
269 llm_api_token: LlmApiToken,
270 use_cloud: bool,
271 ) -> Result<ListModelsResponse> {
272 let http_client = &client.http_client();
273 let token = llm_api_token.acquire(&client).await?;
274
275 let request = http_client::Request::builder()
276 .method(Method::GET)
277 .uri(
278 http_client
279 .build_zed_llm_url("/models", &[], use_cloud)?
280 .as_ref(),
281 )
282 .header("Authorization", format!("Bearer {token}"))
283 .body(AsyncBody::empty())?;
284 let mut response = http_client
285 .send(request)
286 .await
287 .context("failed to send list models request")?;
288
289 if response.status().is_success() {
290 let mut body = String::new();
291 response.body_mut().read_to_string(&mut body).await?;
292 return Ok(serde_json::from_str(&body)?);
293 } else {
294 let mut body = String::new();
295 response.body_mut().read_to_string(&mut body).await?;
296 anyhow::bail!(
297 "error listing models.\nStatus: {:?}\nBody: {body}",
298 response.status(),
299 );
300 }
301 }
302}
303
304impl CloudLanguageModelProvider {
    /// Creates the provider and spawns a task that mirrors the client's
    /// connection status into the provider state.
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        // Hold only a weak handle so the watcher task doesn't keep the
        // state entity alive.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        // Only notify when the status actually changed.
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    // State was dropped; stop watching.
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }
333
334 fn create_language_model(
335 &self,
336 model: Arc<zed_llm_client::LanguageModel>,
337 llm_api_token: LlmApiToken,
338 ) -> Arc<dyn LanguageModel> {
339 Arc::new(CloudLanguageModel {
340 id: LanguageModelId(SharedString::from(model.id.0.clone())),
341 model,
342 llm_api_token: llm_api_token.clone(),
343 client: self.client.clone(),
344 request_limiter: RateLimiter::new(4),
345 })
346 }
347}
348
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the internal state entity so the UI can observe changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
356
357impl LanguageModelProvider for CloudLanguageModelProvider {
358 fn id(&self) -> LanguageModelProviderId {
359 PROVIDER_ID
360 }
361
362 fn name(&self) -> LanguageModelProviderName {
363 PROVIDER_NAME
364 }
365
366 fn icon(&self) -> IconName {
367 IconName::AiZed
368 }
369
370 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
371 let default_model = self.state.read(cx).default_model.clone()?;
372 let llm_api_token = self.state.read(cx).llm_api_token.clone();
373 Some(self.create_language_model(default_model, llm_api_token))
374 }
375
376 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
377 let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
378 let llm_api_token = self.state.read(cx).llm_api_token.clone();
379 Some(self.create_language_model(default_fast_model, llm_api_token))
380 }
381
382 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
383 let llm_api_token = self.state.read(cx).llm_api_token.clone();
384 self.state
385 .read(cx)
386 .recommended_models
387 .iter()
388 .cloned()
389 .map(|model| self.create_language_model(model, llm_api_token.clone()))
390 .collect()
391 }
392
393 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
394 let llm_api_token = self.state.read(cx).llm_api_token.clone();
395 self.state
396 .read(cx)
397 .models
398 .iter()
399 .cloned()
400 .map(|model| self.create_language_model(model, llm_api_token.clone()))
401 .collect()
402 }
403
404 fn is_authenticated(&self, cx: &App) -> bool {
405 let state = self.state.read(cx);
406 !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
407 }
408
409 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
410 Task::ready(Ok(()))
411 }
412
413 fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
414 cx.new(|_| ConfigurationView::new(self.state.clone()))
415 .into()
416 }
417
418 fn must_accept_terms(&self, cx: &App) -> bool {
419 !self.state.read(cx).has_accepted_terms_of_service(cx)
420 }
421
422 fn render_accept_terms(
423 &self,
424 view: LanguageModelProviderTosView,
425 cx: &mut App,
426 ) -> Option<AnyElement> {
427 let state = self.state.read(cx);
428 if state.has_accepted_terms_of_service(cx) {
429 return None;
430 }
431 Some(
432 render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
433 let state = self.state.clone();
434 move |_window, cx| {
435 state.update(cx, |state, cx| state.accept_terms_of_service(cx));
436 }
437 })
438 .into_any_element(),
439 )
440 }
441
442 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
443 Task::ready(Ok(()))
444 }
445}
446
/// Renders the terms-of-service notice with an accept button, laid out
/// differently depending on where in the UI it appears.
fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    // Link that opens the terms of service in the browser.
    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    // Accept button; disabled while an acceptance is already in flight.
    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        // Compact single-row layout for the empty-thread state.
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        // Stacked layout; the accept button's alignment depends on which
        // view is hosting the prompt.
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}
523
/// A single model served through Zed's hosted LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Throttles outgoing requests; constructed with `RateLimiter::new(4)`
    // in `create_language_model`.
    request_limiter: RateLimiter,
}
531
/// A successful `/completions` response plus metadata read from its headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    /// Usage parsed from headers; `None` when the server streams status
    /// messages (usage then arrives in-band).
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}
538
impl CloudLanguageModel {
    /// Sends a completion request to the LLM API, retrying exactly once with
    /// a refreshed token if the server reports the current one expired.
    /// Maps plan-limit and payment failures to their typed errors; any other
    /// failure becomes an [`ApiError`] carrying status, headers, and body.
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
        use_cloud: bool,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST).uri(
                http_client
                    .build_zed_llm_url("/completions", &[], use_cloud)?
                    .as_ref(),
            );
            // The version header is only attached when an app version is known.
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage arrives
                // in-band rather than via response headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // Expired-token response: refresh and retry once only.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // Const pattern: this arm matches only when the limited
                // resource equals MODEL_REQUESTS_RESOURCE_HEADER_VALUE
                // (constants in patterns compare by value, not bind).
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        // Translate the wire-format plan into the proto plan.
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: surface status, headers, and body text.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
647
/// Error for non-success responses from the LLM API, retaining the raw
/// status, body text, and response headers for later inspection.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
655
656impl From<ApiError> for LanguageModelCompletionError {
657 fn from(error: ApiError) -> Self {
658 let retry_after = None;
659 LanguageModelCompletionError::from_http_status(
660 PROVIDER_NAME,
661 error.status,
662 error.body,
663 retry_after,
664 )
665 }
666}
667
668impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    // The provider is always Zed cloud; the *upstream* provider below
    // identifies which vendor actually serves the model.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    // Every tool-choice mode is supported, regardless of upstream provider.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }
726
727 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
728 match self.model.provider {
729 zed_llm_client::LanguageModelProvider::Anthropic
730 | zed_llm_client::LanguageModelProvider::OpenAi => {
731 LanguageModelToolSchemaFormat::JsonSchema
732 }
733 zed_llm_client::LanguageModelProvider::Google => {
734 LanguageModelToolSchemaFormat::JsonSchemaSubset
735 }
736 }
737 }
738
    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    /// Burn-mode (max-mode) token limit; `None` unless the model both
    /// supports max mode and advertises a limit for it.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }
749
750 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
751 match &self.model.provider {
752 zed_llm_client::LanguageModelProvider::Anthropic => {
753 Some(LanguageModelCacheConfiguration {
754 min_total_token: 2_048,
755 should_speculate: true,
756 max_cache_anchors: 4,
757 })
758 }
759 zed_llm_client::LanguageModelProvider::OpenAi
760 | zed_llm_client::LanguageModelProvider::Google => None,
761 }
762 }
763
    /// Counts tokens for a request. Anthropic and OpenAI are counted via
    /// local helpers; Google goes through the LLM API's `/count_tokens`
    /// endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                // Unknown OpenAI model ids fail early with an error future.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    // The Google request is embedded verbatim inside the
                    // LLM API's count-tokens body.
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[], use_cloud)?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }
832
    /// Streams a completion: translates the request into the upstream
    /// provider's wire format, POSTs it to the LLM API via
    /// `perform_llm_completion`, and maps the streamed response lines back
    /// into `LanguageModelCompletionEvent`s, appending usage and
    /// tool-use-limit events derived from the response headers.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let use_cloud = cx
            .update(|cx| cx.has_flag::<ZedCloudFeatureFlag>())
            .unwrap_or(false);
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    // Thinking variants are distinguished by the "-thinking"
                    // id suffix added when the model list was built.
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                        use_cloud,
                    )
                    .await
                    // Preserve typed API errors so callers can react to the
                    // HTTP status code.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                // Unknown OpenAI model ids fail early with an error future.
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                        use_cloud,
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                        use_cloud,
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
1001}
1002
/// A single element of a streamed completion: either a request-status update
/// from the LLM API or a provider-specific event payload.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
1009
1010fn map_cloud_completion_events<T, F>(
1011 stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
1012 mut map_callback: F,
1013) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1014where
1015 T: DeserializeOwned + 'static,
1016 F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1017 + Send
1018 + 'static,
1019{
1020 stream
1021 .flat_map(move |event| {
1022 futures::stream::iter(match event {
1023 Err(error) => {
1024 vec![Err(LanguageModelCompletionError::from(error))]
1025 }
1026 Ok(CloudCompletionEvent::Status(event)) => {
1027 vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
1028 }
1029 Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
1030 })
1031 })
1032 .boxed()
1033}
1034
1035fn usage_updated_event<T>(
1036 usage: Option<ModelRequestUsage>,
1037) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1038 futures::stream::iter(usage.map(|usage| {
1039 Ok(CloudCompletionEvent::Status(
1040 CompletionRequestStatus::UsageUpdated {
1041 amount: usage.amount as usize,
1042 limit: usage.limit,
1043 },
1044 ))
1045 }))
1046}
1047
1048fn tool_use_limit_reached_event<T>(
1049 tool_use_limit_reached: bool,
1050) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1051 futures::stream::iter(tool_use_limit_reached.then(|| {
1052 Ok(CloudCompletionEvent::Status(
1053 CompletionRequestStatus::ToolUseLimitReached,
1054 ))
1055 }))
1056}
1057
/// Converts a streaming response body into a stream of completion events,
/// parsing one JSON value per line.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        // Unfold state: a reusable line buffer plus the buffered body reader.
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // EOF terminates the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    // When the server supports status messages, each line is
                    // a tagged CloudCompletionEvent; otherwise it is a bare
                    // provider event.
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    // Reuse the buffer for the next line.
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
1082
/// Renderable state for the Zed AI provider's configuration UI.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when false, a sign-in prompt is shown.
    is_connected: bool,
    /// The user's current plan, if known.
    plan: Option<proto::Plan>,
    /// Start and end of the current subscription period, if any.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    /// Whether the user can still start a free trial.
    eligible_for_trial: bool,
    /// Whether the user has accepted the terms of service; gates the
    /// subscription UI behind the accept-terms prompt.
    has_accepted_terms_of_service: bool,
    /// Whether a terms-of-service acceptance is currently in flight.
    accept_terms_of_service_in_progress: bool,
    /// Invoked when the user accepts the terms of service.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    /// Invoked when the user clicks the "Sign In" button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1094
1095impl RenderOnce for ZedAiConfiguration {
1096 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1097 const ZED_PRICING_URL: &str = "https://zed.dev/pricing";
1098
1099 let is_pro = self.plan == Some(proto::Plan::ZedPro);
1100 let subscription_text = match (self.plan, self.subscription_period) {
1101 (Some(proto::Plan::ZedPro), Some(_)) => {
1102 "You have access to Zed's hosted LLMs through your Zed Pro subscription."
1103 }
1104 (Some(proto::Plan::ZedProTrial), Some(_)) => {
1105 "You have access to Zed's hosted LLMs through your Zed Pro trial."
1106 }
1107 (Some(proto::Plan::Free), Some(_)) => {
1108 "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
1109 }
1110 _ => {
1111 if self.eligible_for_trial {
1112 "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
1113 } else {
1114 "Subscribe for access to Zed's hosted LLMs."
1115 }
1116 }
1117 };
1118 let manage_subscription_buttons = if is_pro {
1119 h_flex().child(
1120 Button::new("manage_settings", "Manage Subscription")
1121 .style(ButtonStyle::Tinted(TintColor::Accent))
1122 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1123 )
1124 } else {
1125 h_flex()
1126 .gap_2()
1127 .child(
1128 Button::new("learn_more", "Learn more")
1129 .style(ButtonStyle::Subtle)
1130 .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
1131 )
1132 .child(
1133 Button::new(
1134 "upgrade",
1135 if self.plan.is_none() && self.eligible_for_trial {
1136 "Start Trial"
1137 } else {
1138 "Upgrade"
1139 },
1140 )
1141 .style(ButtonStyle::Subtle)
1142 .color(Color::Accent)
1143 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
1144 )
1145 };
1146
1147 if self.is_connected {
1148 v_flex()
1149 .gap_3()
1150 .w_full()
1151 .when(!self.has_accepted_terms_of_service, |this| {
1152 this.child(render_accept_terms(
1153 LanguageModelProviderTosView::Configuration,
1154 self.accept_terms_of_service_in_progress,
1155 {
1156 let callback = self.accept_terms_of_service_callback.clone();
1157 move |window, cx| (callback)(window, cx)
1158 },
1159 ))
1160 })
1161 .when(self.has_accepted_terms_of_service, |this| {
1162 this.child(subscription_text)
1163 .child(manage_subscription_buttons)
1164 })
1165 } else {
1166 v_flex()
1167 .gap_2()
1168 .child(Label::new("Use Zed AI to access hosted language models."))
1169 .child(
1170 Button::new("sign_in", "Sign In")
1171 .icon_color(Color::Muted)
1172 .icon(IconName::Github)
1173 .icon_position(IconPosition::Start)
1174 .on_click({
1175 let callback = self.sign_in_callback.clone();
1176 move |_, window, cx| (callback)(window, cx)
1177 }),
1178 )
1179 }
1180 }
1181}
1182
/// Entity-backed view that bridges provider `State` to the
/// `ZedAiConfiguration` UI, owning the callbacks it hands to each render.
struct ConfigurationView {
    // Provider state read on every render.
    state: Entity<State>,
    // Accepts the terms of service via the provider state.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    // Kicks off authentication via the provider state.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1188
1189impl ConfigurationView {
1190 fn new(state: Entity<State>) -> Self {
1191 let accept_terms_of_service_callback = Arc::new({
1192 let state = state.clone();
1193 move |_window: &mut Window, cx: &mut App| {
1194 state.update(cx, |state, cx| {
1195 state.accept_terms_of_service(cx);
1196 });
1197 }
1198 });
1199
1200 let sign_in_callback = Arc::new({
1201 let state = state.clone();
1202 move |_window: &mut Window, cx: &mut App| {
1203 state.update(cx, |state, cx| {
1204 state.authenticate(cx).detach_and_log_err(cx);
1205 });
1206 }
1207 });
1208
1209 Self {
1210 state,
1211 accept_terms_of_service_callback,
1212 sign_in_callback,
1213 }
1214 }
1215}
1216
1217impl Render for ConfigurationView {
1218 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1219 let state = self.state.read(cx);
1220 let user_store = state.user_store.read(cx);
1221
1222 ZedAiConfiguration {
1223 is_connected: !state.is_signed_out(),
1224 plan: user_store.current_plan(),
1225 subscription_period: user_store.subscription_period(),
1226 eligible_for_trial: user_store.trial_started_at().is_none(),
1227 has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
1228 accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
1229 accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
1230 sign_in_callback: self.sign_in_callback.clone(),
1231 }
1232 }
1233}
1234
1235impl Component for ZedAiConfiguration {
1236 fn scope() -> ComponentScope {
1237 ComponentScope::Agent
1238 }
1239
1240 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1241 fn configuration(
1242 is_connected: bool,
1243 plan: Option<proto::Plan>,
1244 eligible_for_trial: bool,
1245 has_accepted_terms_of_service: bool,
1246 ) -> AnyElement {
1247 ZedAiConfiguration {
1248 is_connected,
1249 plan,
1250 subscription_period: plan
1251 .is_some()
1252 .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1253 eligible_for_trial,
1254 has_accepted_terms_of_service,
1255 accept_terms_of_service_in_progress: false,
1256 accept_terms_of_service_callback: Arc::new(|_, _| {}),
1257 sign_in_callback: Arc::new(|_, _| {}),
1258 }
1259 .into_any_element()
1260 }
1261
1262 Some(
1263 v_flex()
1264 .p_4()
1265 .gap_4()
1266 .children(vec![
1267 single_example("Not connected", configuration(false, None, false, true)),
1268 single_example(
1269 "Accept Terms of Service",
1270 configuration(true, None, true, false),
1271 ),
1272 single_example(
1273 "No Plan - Not eligible for trial",
1274 configuration(true, None, false, true),
1275 ),
1276 single_example(
1277 "No Plan - Eligible for trial",
1278 configuration(true, None, true, true),
1279 ),
1280 single_example(
1281 "Free Plan",
1282 configuration(true, Some(proto::Plan::Free), true, true),
1283 ),
1284 single_example(
1285 "Zed Pro Trial Plan",
1286 configuration(true, Some(proto::Plan::ZedProTrial), true, true),
1287 ),
1288 single_example(
1289 "Zed Pro Plan",
1290 configuration(true, Some(proto::Plan::ZedPro), true, true),
1291 ),
1292 ])
1293 .into_any_element(),
1294 )
1295 }
1296}