use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
    ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
    LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
    RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

pub const PROVIDER_NAME: &str = "Zed";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
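
// Illustrative only: a hypothetical settings entry that deserializes into an
// `AvailableModel` with a thinking `ModelMode`. Field names follow the serde
// attributes above; the concrete values are made up:
//
//     {
//         "provider": "anthropic",
//         "name": "claude-3-7-sonnet-latest",
//         "display_name": "Claude 3.7 Sonnet (custom)",
//         "max_tokens": 200000,
//         "mode": { "type": "thinking", "budget_tokens": 4096 }
//     }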
102
103impl From<ModelMode> for AnthropicModelMode {
104 fn from(value: ModelMode) -> Self {
105 match value {
106 ModelMode::Default => AnthropicModelMode::Default,
107 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
108 }
109 }
110}
111
112pub struct CloudLanguageModelProvider {
113 client: Arc<Client>,
114 state: gpui::Entity<State>,
115 _maintain_client_status: Task<()>,
116}
117
118pub struct State {
119 client: Arc<Client>,
120 llm_api_token: LlmApiToken,
121 user_store: Entity<UserStore>,
122 status: client::Status,
123 accept_terms: Option<Task<Result<()>>>,
124 models: Vec<Arc<zed_llm_client::LanguageModel>>,
125 default_model: Option<Arc<zed_llm_client::LanguageModel>>,
126 default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
127 recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
128 _fetch_models_task: Task<()>,
129 _settings_subscription: Subscription,
130 _llm_token_subscription: Subscription,
131}
132
133impl State {
134 fn new(
135 client: Arc<Client>,
136 user_store: Entity<UserStore>,
137 status: client::Status,
138 cx: &mut Context<Self>,
139 ) -> Self {
140 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
141
142 Self {
143 client: client.clone(),
144 llm_api_token: LlmApiToken::default(),
145 user_store,
146 status,
147 accept_terms: None,
148 models: Vec::new(),
149 default_model: None,
150 default_fast_model: None,
151 recommended_models: Vec::new(),
152 _fetch_models_task: cx.spawn(async move |this, cx| {
153 maybe!(async move {
154 let (client, llm_api_token) = this
155 .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;
156
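                    // Poll until the client reports a connected status; the
                    // models endpoint requires an active connection.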
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();

        // The token is optional here; we provide it when we have one because some models are conditionally available.
        let token = llm_api_token.acquire(&client).await.ok();

        let mut request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref());
        if let Some(token) = token {
            request = request.header("Authorization", format!("Bearer {token}"));
        }
        let request = request.body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Ok(serde_json::from_str(&body)?);
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    const MAX_RETRIES: usize = 3;

    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut retries_remaining = Self::MAX_RETRIES;
        let mut retry_delay = Duration::from_secs(1);
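        // Transient 5xx failures retry with exponential backoff (1s, 2s, 4s);
        // an expired-token response refreshes the token and retries immediately.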

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            } else if response
                .headers()
                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                .is_some()
            {
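                // The server flagged our LLM token as expired; mint a fresh
                // one and retry the request with it.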
                // Guard against looping forever (and a usize underflow) if the
                // server keeps rejecting freshly minted tokens as expired.
                if retries_remaining == 0 {
                    anyhow::bail!(
                        "cloud language model completion failed: LLM token still expired after {} refreshes",
                        Self::MAX_RETRIES
                    );
                }
                retries_remaining -= 1;
                token = llm_api_token.refresh(&client).await?;
            } else if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }

                anyhow::bail!("Forbidden");
            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
                // If we encounter an error in the 500 range, retry after a delay.
                // We've seen at least these in the wild from API providers:
                // * 500 Internal Server Error
                // * 502 Bad Gateway
                // * 529 Service Overloaded

                if retries_remaining == 0 {
                    let mut body = String::new();
                    response.body_mut().read_to_string(&mut body).await?;
                    anyhow::bail!(
                        "cloud language model completion failed after {} retries with status {status}: {body}",
                        Self::MAX_RETRIES
                    );
                }

                Timer::after(retry_delay).await;

                retries_remaining -= 1;
                retry_delay *= 2; // If it fails again, wait longer.
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                return Err(anyhow!(ApiError { status, body }));
            }
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
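                    // Thinking variants are synthesized client-side in `fetch_models`
                    // by appending "-thinking" to the id, so that suffix selects
                    // Anthropic's thinking mode here with a fixed default budget.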
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

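/// A single event in the cloud completions stream.
///
/// With the serde attributes below, events are externally tagged in
/// snake_case: a line is either `{"status": ...}` (a control message from
/// Zed's servers) or `{"event": ...}` (a provider-specific payload).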
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}

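/// Flattens a stream of cloud events into completion events: stream errors
/// and status messages each map to a single item, while every provider event
/// may expand to several items via `map_callback`.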
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::Other(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

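/// Yields at most one synthetic `UsageUpdated` status event, used when the
/// server reported usage via response headers rather than status messages.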
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

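/// Yields at most one synthetic `ToolUseLimitReached` status event, used when
/// the corresponding response header was present.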
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

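/// Parses the response body as a stream of newline-delimited JSON events.
/// When the server advertised status-message support, each line is a tagged
/// `CloudCompletionEvent<T>`; otherwise each line is a bare provider event `T`.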
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_connected = !self.state.read(cx).is_signed_out();
        let user_store = self.state.read(cx).user_store.read(cx);
        let plan = user_store.current_plan();
        let subscription_period = user_store.subscription_period();
        let eligible_for_trial = user_store.trial_started_at().is_none();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (plan, subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
                )
                .child(
                    Button::new("upgrade", "Upgrade")
                        .style(ButtonStyle::Subtle)
                        .color(Color::Accent)
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                )
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}