use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
    RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. claude-3-5-sonnet-20240620.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
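
// A (hypothetical) user-settings entry that would deserialize into an
// `AvailableModel` in "thinking" mode. The field names follow the serde
// attributes on the types above; the values are illustrative only:
//
// {
//     "provider": "anthropic",
//     "name": "claude-3-5-sonnet-20240620",
//     "display_name": "Claude 3.5 Sonnet (Thinking)",
//     "max_tokens": 200000,
//     "mode": { "type": "thinking", "budget_tokens": 4096 }
// }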

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

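/// The Zed-hosted language model provider, which serves models through Zed's
/// cloud LLM service on behalf of the signed-in user.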
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

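/// Observable state shared by the provider and its models: the client
/// connection status, the LLM API token, and the model list fetched from the
/// server.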
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            Ok(serde_json::from_str(&body)?)
        } else {
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

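/// A single model served through Zed's cloud LLM service.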
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

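/// A successful response from the completions endpoint, along with metadata
/// parsed from the response headers.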
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

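    /// Sends a completion request to the Zed LLM service, refreshing the LLM
    /// token when the server reports it expired and retrying 5xx responses
    /// with exponential backoff, up to `MAX_RETRIES` attempts.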
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                // Guard against underflow if the server keeps flagging the token as expired.
                if retries_remaining == 0 {
                    anyhow::bail!("LLM token still expired after {} refresh attempts", Self::MAX_RETRIES);
                }
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                anyhow::bail!("Forbidden");
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    );
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

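/// A single line of a streamed completion response: either a status update
/// from the Zed LLM service or an event from the underlying provider.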
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}

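/// Converts a stream of `CloudCompletionEvent`s into language model completion
/// events, passing provider events through `map_callback` and surfacing status
/// updates and errors directly.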
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::Other(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

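/// Yields a single `UsageUpdated` status event if usage information was parsed
/// from the response headers, and nothing otherwise.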
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

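/// Yields a single `ToolUseLimitReached` status event if the server flagged
/// the limit in its response headers, and nothing otherwise.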
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

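/// Parses a streamed response body as newline-delimited JSON. When the server
/// supports status messages, each line is a `CloudCompletionEvent<T>`;
/// otherwise each line is a bare provider event and gets wrapped in
/// `CloudCompletionEvent::Event`.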
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_connected = !self.state.read(cx).is_signed_out();
        let user_store = self.state.read(cx).user_store.read(cx);
        let plan = user_store.current_plan();
        let subscription_period = user_store.subscription_period();
        let eligible_for_trial = user_store.trial_started_at().is_none();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (plan, subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
                )
                .child(
                    Button::new("upgrade", "Upgrade")
                        .style(ButtonStyle::Subtle)
                        .color(Color::Accent)
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                )
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}
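
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the serde shape implied by the attributes on
    // `ModelMode` (`tag = "type"`, `rename_all = "lowercase"`); the budget
    // value here is illustrative only.
    #[test]
    fn model_mode_deserializes_thinking_variant() {
        let mode: ModelMode =
            serde_json::from_str(r#"{"type":"thinking","budget_tokens":4096}"#).unwrap();
        assert_eq!(
            mode,
            ModelMode::Thinking {
                budget_tokens: Some(4096)
            }
        );
    }

    // Checks that the `From<ModelMode>` conversion above preserves the
    // thinking budget; `matches!` avoids assuming `AnthropicModelMode`
    // implements `PartialEq`.
    #[test]
    fn model_mode_converts_to_anthropic_mode() {
        let mode = ModelMode::Thinking {
            budget_tokens: Some(1024),
        };
        assert!(matches!(
            AnthropicModelMode::from(mode),
            AnthropicModelMode::Thinking {
                budget_tokens: Some(1024)
            }
        ));
    }
}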