1use anthropic::AnthropicModelMode;
2use anyhow::{Context as _, Result, anyhow};
3use chrono::{DateTime, Utc};
4use client::{Client, ModelRequestUsage, UserStore, zed_urls};
5use futures::{
6 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
7};
8use google_ai::GoogleModelMode;
9use gpui::{
10 AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
11};
12use http_client::http::{HeaderMap, HeaderValue};
13use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
14use language_model::{
15 AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
16 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
17 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
18 LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
19 LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
20 ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
21};
22use proto::Plan;
23use release_channel::AppVersion;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize, de::DeserializeOwned};
26use settings::SettingsStore;
27use smol::io::{AsyncReadExt, BufReader};
28use std::pin::Pin;
29use std::str::FromStr as _;
30use std::sync::Arc;
31use std::time::Duration;
32use thiserror::Error;
33use ui::{TintColor, prelude::*};
34use util::{ResultExt as _, maybe};
35use zed_llm_client::{
36 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
37 CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
38 ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
39 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
40 TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
41};
42
43use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
44use crate::provider::google::{GoogleEventMapper, into_google};
45use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
46
/// Provider identifier for Zed's hosted models, re-exported from `language_model`.
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Display name for Zed's hosted provider, re-exported from `language_model`.
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
49
/// User settings for the zed.dev language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Custom models configured by the user for this provider.
    pub available_models: Vec<AvailableModel>,
}
54
/// The upstream provider a user-configured model is routed to.
/// Serialized in lowercase (e.g. `"anthropic"`, `"openai"`, `"google"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
62
/// A user-configured model made available through the Zed provider
/// (see [`ZedDotDevSettings::available_models`]).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
89
/// The reasoning mode a configured model runs in.
/// Serialized with an internal `type` tag, e.g. `{"type": "thinking"}`.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion behavior.
    #[default]
    Default,
    /// Extended-reasoning ("thinking") behavior.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
100
101impl From<ModelMode> for AnthropicModelMode {
102 fn from(value: ModelMode) -> Self {
103 match value {
104 ModelMode::Default => AnthropicModelMode::Default,
105 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
106 }
107 }
108}
109
/// Language-model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    // Background task that mirrors the client's connection status into
    // `State`; kept alive for the lifetime of the provider.
    _maintain_client_status: Task<()>,
}
115
/// Observable state shared between the provider and its configuration UI.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    // Latest connection status mirrored from the client.
    status: client::Status,
    // In-flight terms-of-service acceptance, if any; `Some` disables the
    // accept button while the request is pending.
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    // Models fetched from the server (including synthesized "-thinking" variants).
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
130
impl State {
    /// Builds the provider state and kicks off background work: a task that
    /// fetches the model list once the client is connected, plus subscriptions
    /// for settings changes and LLM-token refresh events.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Poll (100ms) until the client reports being connected;
                    // the models endpoint requires an authenticated connection.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            // Resolve the server-provided default/fast/recommended
                            // ids against the expanded model list.
                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            // Re-render whenever global settings change.
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            // Refresh the LLM API token in the background when notified.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// Whether the user is currently signed out of the Zed client.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Signs in and connects the client, then notifies observers.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    /// Whether the current user has accepted the terms of service.
    /// Returns `false` when there is no answer available.
    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Kicks off terms-of-service acceptance; when the request finishes
    /// (success or failure), the in-progress task is cleared and observers
    /// are notified so the UI can update.
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }

    /// Fetches the list of available models from the LLM API, authorized via
    /// the (lazily acquired) LLM API token.
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Ok(serde_json::from_str(&body)?);
        } else {
            // Include the response body in the error to aid debugging.
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}
296
297impl CloudLanguageModelProvider {
298 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
299 let mut status_rx = client.status();
300 let status = *status_rx.borrow();
301
302 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
303
304 let state_ref = state.downgrade();
305 let maintain_client_status = cx.spawn(async move |cx| {
306 while let Some(status) = status_rx.next().await {
307 if let Some(this) = state_ref.upgrade() {
308 _ = this.update(cx, |this, cx| {
309 if this.status != status {
310 this.status = status;
311 cx.notify();
312 }
313 });
314 } else {
315 break;
316 }
317 }
318 });
319
320 Self {
321 client,
322 state: state.clone(),
323 _maintain_client_status: maintain_client_status,
324 }
325 }
326
327 fn create_language_model(
328 &self,
329 model: Arc<zed_llm_client::LanguageModel>,
330 llm_api_token: LlmApiToken,
331 ) -> Arc<dyn LanguageModel> {
332 Arc::new(CloudLanguageModel {
333 id: LanguageModelId(SharedString::from(model.id.0.clone())),
334 model,
335 llm_api_token: llm_api_token.clone(),
336 client: self.client.clone(),
337 request_limiter: RateLimiter::new(4),
338 })
339 }
340}
341
342impl LanguageModelProviderState for CloudLanguageModelProvider {
343 type ObservableEntity = State;
344
345 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
346 Some(self.state.clone())
347 }
348}
349
350impl LanguageModelProvider for CloudLanguageModelProvider {
351 fn id(&self) -> LanguageModelProviderId {
352 PROVIDER_ID
353 }
354
355 fn name(&self) -> LanguageModelProviderName {
356 PROVIDER_NAME
357 }
358
359 fn icon(&self) -> IconName {
360 IconName::AiZed
361 }
362
363 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
364 let default_model = self.state.read(cx).default_model.clone()?;
365 let llm_api_token = self.state.read(cx).llm_api_token.clone();
366 Some(self.create_language_model(default_model, llm_api_token))
367 }
368
369 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
370 let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
371 let llm_api_token = self.state.read(cx).llm_api_token.clone();
372 Some(self.create_language_model(default_fast_model, llm_api_token))
373 }
374
375 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
376 let llm_api_token = self.state.read(cx).llm_api_token.clone();
377 self.state
378 .read(cx)
379 .recommended_models
380 .iter()
381 .cloned()
382 .map(|model| self.create_language_model(model, llm_api_token.clone()))
383 .collect()
384 }
385
386 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
387 let llm_api_token = self.state.read(cx).llm_api_token.clone();
388 self.state
389 .read(cx)
390 .models
391 .iter()
392 .cloned()
393 .map(|model| self.create_language_model(model, llm_api_token.clone()))
394 .collect()
395 }
396
397 fn is_authenticated(&self, cx: &App) -> bool {
398 let state = self.state.read(cx);
399 !state.is_signed_out() && state.has_accepted_terms_of_service(cx)
400 }
401
402 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
403 Task::ready(Ok(()))
404 }
405
406 fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
407 cx.new(|_| ConfigurationView::new(self.state.clone()))
408 .into()
409 }
410
411 fn must_accept_terms(&self, cx: &App) -> bool {
412 !self.state.read(cx).has_accepted_terms_of_service(cx)
413 }
414
415 fn render_accept_terms(
416 &self,
417 view: LanguageModelProviderTosView,
418 cx: &mut App,
419 ) -> Option<AnyElement> {
420 let state = self.state.read(cx);
421 if state.has_accepted_terms_of_service(cx) {
422 return None;
423 }
424 Some(
425 render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
426 let state = self.state.clone();
427 move |_window, cx| {
428 state.update(cx, |state, cx| state.accept_terms_of_service(cx));
429 }
430 })
431 .into_any_element(),
432 )
433 }
434
435 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
436 Task::ready(Ok(()))
437 }
438}
439
/// Builds the terms-of-service acceptance UI, laid out differently depending
/// on where it is shown (`view_kind`). The accept button is disabled while
/// `accept_terms_of_service_in_progress` is set, and clicking it invokes
/// `accept_terms_callback`.
fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    // Link to the hosted terms-of-service page.
    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    // The accept button is styled compactly in the empty-state variant and
    // full-width elsewhere.
    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        // Compact single-row layout for the empty thread state.
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        // Stacked layout; button alignment varies per view kind.
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    // Unreachable in practice (handled by the branch above),
                    // but the match must be exhaustive.
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}
516
/// A single Zed-hosted model, routing requests through the LLM API.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    // Server-provided description of the model (capabilities, limits, provider).
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Bounds the number of concurrent requests for this model.
    request_limiter: RateLimiter,
}
524
/// A successful `/completions` response plus metadata decoded from its headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    // Usage parsed from headers; `None` when the server streams status
    // messages instead (see `includes_status_messages`).
    usage: Option<ModelRequestUsage>,
    // Set when the server signals the tool-use limit was reached.
    tool_use_limit_reached: bool,
    // Whether the server will embed status messages in the event stream.
    includes_status_messages: bool,
}
531
impl CloudLanguageModel {
    /// Sends a completion request to the LLM API.
    ///
    /// On success, returns the streaming response plus metadata decoded from
    /// response headers. On failure, maps known statuses to typed errors
    /// (`ModelRequestLimitReachedError`, `PaymentRequiredError`) and wraps
    /// anything else in `ApiError`.
    ///
    /// If the server reports an expired LLM token, the token is refreshed
    /// once and the request is retried.
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        // Loops at most twice: the initial attempt plus one retry after a
        // token refresh.
        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            // Advertise the app version when known, so the server can adapt.
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Tell the server we can consume in-stream status messages.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When status messages are streamed, usage arrives in-stream
                // rather than via headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            // Expired token: refresh once and retry the request.
            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                // `if let Some(CONST)` matches when the header value equals
                // the const (consts in patterns compare by value).
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        // Translate the wire-level plan into the proto plan
                        // expected by the error type.
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Any other failure: surface status, body, and headers.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
637
/// Error for a non-success HTTP response from the LLM API, carrying the
/// status code, response body, and response headers.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
645
646impl From<ApiError> for LanguageModelCompletionError {
647 fn from(error: ApiError) -> Self {
648 let retry_after = None;
649 LanguageModelCompletionError::from_http_status(
650 PROVIDER_NAME,
651 error.status,
652 error.body,
653 retry_after,
654 )
655 }
656}
657
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// Id of the upstream provider this Zed-hosted model routes to.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    /// Display name of the upstream provider this Zed-hosted model routes to.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    // All tool-choice modes are supported regardless of upstream provider.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    // "Burn mode" maps onto the server's "max mode" capability flag.
    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    // Google requires a restricted JSON-schema subset for tool inputs.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    /// Max-mode token count; `None` when the model doesn't support max mode
    /// even if the server supplied a value.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    /// Prompt-caching parameters; only Anthropic models support caching here.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    /// Counts tokens for `request`: locally for Anthropic/OpenAI, via the LLM
    /// API's `/count_tokens` endpoint for Google.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion through the LLM API, converting the request into
    /// the upstream provider's format and mapping its events back into
    /// `LanguageModelCompletionEvent`s. Requests are throttled through
    /// `request_limiter`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    // Thinking variants are distinguished by an id suffix
                    // (see the model-list expansion in `State::new`).
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // NOTE(review): only this Anthropic branch downcasts
                    // `ApiError` into a typed completion error; the OpenAI and
                    // Google branches below propagate it raw — confirm whether
                    // that asymmetry is intentional.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            // Append synthetic usage/limit events after the
                            // server's event stream ends.
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
985
/// One event in the cloud completion stream: either a request-status update
/// injected by the server, or a provider-specific completion event `T`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    /// Server-side status update (e.g. usage updated, tool-use limit reached).
    Status(CompletionRequestStatus),
    /// An event from the underlying provider's stream.
    Event(T),
}
992
993fn map_cloud_completion_events<T, F>(
994 stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
995 mut map_callback: F,
996) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
997where
998 T: DeserializeOwned + 'static,
999 F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
1000 + Send
1001 + 'static,
1002{
1003 stream
1004 .flat_map(move |event| {
1005 futures::stream::iter(match event {
1006 Err(error) => {
1007 vec![Err(LanguageModelCompletionError::from(error))]
1008 }
1009 Ok(CloudCompletionEvent::Status(event)) => {
1010 vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
1011 }
1012 Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
1013 })
1014 })
1015 .boxed()
1016}
1017
1018fn usage_updated_event<T>(
1019 usage: Option<ModelRequestUsage>,
1020) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1021 futures::stream::iter(usage.map(|usage| {
1022 Ok(CloudCompletionEvent::Status(
1023 CompletionRequestStatus::UsageUpdated {
1024 amount: usage.amount as usize,
1025 limit: usage.limit,
1026 },
1027 ))
1028 }))
1029}
1030
1031fn tool_use_limit_reached_event<T>(
1032 tool_use_limit_reached: bool,
1033) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
1034 futures::stream::iter(tool_use_limit_reached.then(|| {
1035 Ok(CloudCompletionEvent::Status(
1036 CompletionRequestStatus::ToolUseLimitReached,
1037 ))
1038 }))
1039}
1040
/// Turns a streaming HTTP body into a stream of newline-delimited JSON events.
/// When `includes_status_messages` is set, each line is decoded as a full
/// `CloudCompletionEvent<T>`; otherwise each line is a bare `T` wrapped in
/// `CloudCompletionEvent::Event`.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        // State: a reusable line buffer plus the buffered body reader.
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // EOF: terminate the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    // Clear the buffer so the next iteration reads a fresh line.
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
1065
/// Renderable state for the Zed AI provider configuration panel.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    // Whether the user is currently signed in.
    is_connected: bool,
    // The user's current plan, if known.
    plan: Option<proto::Plan>,
    // Start/end timestamps of the current subscription period, if any.
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
    // True while a terms-of-service acceptance request is in flight.
    accept_terms_of_service_in_progress: bool,
    // Invoked when the user accepts the terms of service.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    // Invoked when the user clicks "Sign In".
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1077
impl RenderOnce for ZedAiConfiguration {
    /// Renders the configuration panel: a sign-in prompt when disconnected,
    /// otherwise the terms-of-service prompt or the subscription summary with
    /// its management buttons.
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_pro = self.plan == Some(proto::Plan::ZedPro);
        // Pick the summary line based on plan + an active subscription period;
        // the fallthrough arm covers "no plan" and "no active period".
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        // Pro users get a single "Manage Subscription" button; everyone else
        // gets "Learn more" plus an upgrade/trial call-to-action.
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(|_, _, cx| cx.open_url(ZED_PRICING_URL)),
                )
                .child(
                    Button::new(
                        "upgrade",
                        // Users with no plan who are trial-eligible see
                        // "Start Trial"; everyone else sees "Upgrade".
                        if self.plan.is_none() && self.eligible_for_trial {
                            "Start Trial"
                        } else {
                            "Upgrade"
                        },
                    )
                    .style(ButtonStyle::Subtle)
                    .color(Color::Accent)
                    .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                )
        };

        if self.is_connected {
            // Connected: show either the terms-of-service prompt or the
            // subscription summary — the two `when` guards are exclusive.
            v_flex()
                .gap_3()
                .w_full()
                .when(!self.has_accepted_terms_of_service, |this| {
                    this.child(render_accept_terms(
                        LanguageModelProviderTosView::Configuration,
                        self.accept_terms_of_service_in_progress,
                        {
                            let callback = self.accept_terms_of_service_callback.clone();
                            move |window, cx| (callback)(window, cx)
                        },
                    ))
                })
                .when(self.has_accepted_terms_of_service, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            // Not connected: prompt the user to sign in.
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                )
        }
    }
}
1165
/// Entity-backed view that renders `ZedAiConfiguration` from the live
/// provider `State`.
struct ConfigurationView {
    state: Entity<State>,
    // Cached callbacks shared with each rendered `ZedAiConfiguration`.
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1171
1172impl ConfigurationView {
1173 fn new(state: Entity<State>) -> Self {
1174 let accept_terms_of_service_callback = Arc::new({
1175 let state = state.clone();
1176 move |_window: &mut Window, cx: &mut App| {
1177 state.update(cx, |state, cx| {
1178 state.accept_terms_of_service(cx);
1179 });
1180 }
1181 });
1182
1183 let sign_in_callback = Arc::new({
1184 let state = state.clone();
1185 move |_window: &mut Window, cx: &mut App| {
1186 state.update(cx, |state, cx| {
1187 state.authenticate(cx).detach_and_log_err(cx);
1188 });
1189 }
1190 });
1191
1192 Self {
1193 state,
1194 accept_terms_of_service_callback,
1195 sign_in_callback,
1196 }
1197 }
1198}
1199
1200impl Render for ConfigurationView {
1201 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1202 let state = self.state.read(cx);
1203 let user_store = state.user_store.read(cx);
1204
1205 ZedAiConfiguration {
1206 is_connected: !state.is_signed_out(),
1207 plan: user_store.current_plan(),
1208 subscription_period: user_store.subscription_period(),
1209 eligible_for_trial: user_store.trial_started_at().is_none(),
1210 has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
1211 accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
1212 accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
1213 sign_in_callback: self.sign_in_callback.clone(),
1214 }
1215 }
1216}
1217
1218impl Component for ZedAiConfiguration {
1219 fn scope() -> ComponentScope {
1220 ComponentScope::Agent
1221 }
1222
1223 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1224 fn configuration(
1225 is_connected: bool,
1226 plan: Option<proto::Plan>,
1227 eligible_for_trial: bool,
1228 has_accepted_terms_of_service: bool,
1229 ) -> AnyElement {
1230 ZedAiConfiguration {
1231 is_connected,
1232 plan,
1233 subscription_period: plan
1234 .is_some()
1235 .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
1236 eligible_for_trial,
1237 has_accepted_terms_of_service,
1238 accept_terms_of_service_in_progress: false,
1239 accept_terms_of_service_callback: Arc::new(|_, _| {}),
1240 sign_in_callback: Arc::new(|_, _| {}),
1241 }
1242 .into_any_element()
1243 }
1244
1245 Some(
1246 v_flex()
1247 .p_4()
1248 .gap_4()
1249 .children(vec![
1250 single_example("Not connected", configuration(false, None, false, true)),
1251 single_example(
1252 "Accept Terms of Service",
1253 configuration(true, None, true, false),
1254 ),
1255 single_example(
1256 "No Plan - Not eligible for trial",
1257 configuration(true, None, false, true),
1258 ),
1259 single_example(
1260 "No Plan - Eligible for trial",
1261 configuration(true, None, true, true),
1262 ),
1263 single_example(
1264 "Free Plan",
1265 configuration(true, Some(proto::Plan::Free), true, true),
1266 ),
1267 single_example(
1268 "Zed Pro Trial Plan",
1269 configuration(true, Some(proto::Plan::ZedProTrial), true, true),
1270 ),
1271 single_example(
1272 "Zed Pro Plan",
1273 configuration(true, Some(proto::Plan::ZedPro), true, true),
1274 ),
1275 ])
1276 .into_any_element(),
1277 )
1278 }
1279}