use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Context as _, Result, anyhow};
use client::{Client, UserStore, zed_urls};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
    RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

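/// Settings for the zed.dev language model provider: any additional models the
/// user has configured beyond those served by Zed's cloud.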
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

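/// The upstream provider that serves a user-configured model.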
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

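/// A model made available via settings, described by the upstream provider it
/// belongs to plus the limits and capabilities Zed should assume for it.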
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (`o1-*` models only).
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

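/// The reasoning mode of a model.
///
/// Given the serde attributes below, this deserializes from a tagged object in
/// settings JSON; for example, a model entry can enable extended thinking with
/// `"mode": { "type": "thinking", "budget_tokens": 4096 }`.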
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

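/// The zed.dev language model provider, which routes requests for Anthropic,
/// OpenAI, and Google models through Zed's cloud LLM service.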
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

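/// Shared provider state: the connection status, the LLM API token, and the
/// model lists fetched from the server.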
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate
                                // models on the client, so we need to insert variants for any
                                // model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }

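    /// Fetches the list of available models from Zed's LLM service,
    /// authorizing the request with the current LLM API token.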
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

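/// Renders the terms-of-service acceptance prompt, or returns `None` if the
/// user has already accepted.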
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

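/// A single model served through Zed's cloud LLM service.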
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

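/// A successful completion response, along with usage and limit metadata
/// parsed from the response headers.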
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

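    /// Sends a completion request to Zed's LLM service, refreshing the token
    /// if it has expired and retrying 5xx responses with exponential backoff,
    /// up to `MAX_RETRIES` attempts.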
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    RequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                anyhow::bail!("Forbidden");
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    );
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

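/// A non-success response from the cloud LLM service that doesn't map to a
/// more specific error type.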
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_max_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u32,
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                let request = into_open_ai(request, &model, None);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

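/// A single item in a streamed completion response: either a status update
/// from Zed's servers or an event from the upstream provider.
///
/// With serde's default external tagging, each line deserializes from a
/// `{ "status": ... }` or `{ "event": ... }` object.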
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}

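/// Flattens a stream of cloud completion events into language model completion
/// events, forwarding status updates as-is and expanding each provider event
/// through `map_callback`.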
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::Other(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

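/// Yields a single `UsageUpdated` status event when usage was reported via
/// response headers, and nothing otherwise.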
fn usage_updated_event<T>(
    usage: Option<RequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

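/// Yields a single `ToolUseLimitReached` status event when the corresponding
/// response header was set, and nothing otherwise.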
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

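/// Parses a streamed response body as newline-delimited JSON, wrapping each
/// line in a `CloudCompletionEvent`. Lines are treated as plain provider
/// events unless the server indicated that it interleaves status messages.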
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

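/// The provider's configuration UI: sign-in, terms acceptance, and
/// subscription management.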
struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_connected = !self.state.read(cx).is_signed_out();
        let user_store = self.state.read(cx).user_store.read(cx);
        let plan = user_store.current_plan();
        let subscription_period = user_store.subscription_period();
        let eligible_for_trial = user_store.trial_started_at().is_none();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (plan, subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
                )
                .child(
                    Button::new("upgrade", "Upgrade")
                        .style(ButtonStyle::Subtle)
                        .color(Color::Accent)
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                )
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}