use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
    ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
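
/// A custom model exposed through the `zed.dev` provider's `available_models`
/// setting. An illustrative settings entry might look like the following (the
/// model name is only an example; field names follow this struct's serde
/// derivation):
///
/// ```json
/// {
///     "provider": "anthropic",
///     "name": "claude-3-5-sonnet-20240620",
///     "max_tokens": 200000
/// }
/// ```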
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}
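
/// Whether the model runs in its standard mode or an extended "thinking" mode.
/// Serialized with a `type` tag, so a thinking entry in settings would look
/// roughly like `{"type": "thinking", "budget_tokens": 4096}` (the budget
/// shown is illustrative).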
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}
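
/// The Zed-hosted language model provider, which routes completion requests
/// for Anthropic, OpenAI, and Google models through zed.dev.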
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}
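
/// Observable state for the cloud provider: connection status, the LLM API
/// token, and the model lists fetched from the server.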
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    models: Vec<Arc<zed_llm_client::LanguageModel>>,
    default_model: Option<Arc<zed_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<zed_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<zed_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    // Wait until the client is connected before fetching models.
                    loop {
                        let status = this.read_with(cx, |this, _cx| this.status)?;
                        if matches!(status, client::Status::Connected { .. }) {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    cx.update(|cx| {
                        this.update(cx, |this, cx| {
                            let mut models = Vec::new();

                            for model in response.models {
                                models.push(Arc::new(model.clone()));

                                // Right now we represent thinking variants of models as separate models on the client,
                                // so we need to insert variants for any model that supports thinking.
                                if model.supports_thinking {
                                    models.push(Arc::new(zed_llm_client::LanguageModel {
                                        id: zed_llm_client::LanguageModelId(
                                            format!("{}-thinking", model.id).into(),
                                        ),
                                        display_name: format!("{} Thinking", model.display_name),
                                        ..model
                                    }));
                                }
                            }

                            this.default_model = models
                                .iter()
                                .find(|model| model.id == response.default_model)
                                .cloned();
                            this.default_fast_model = models
                                .iter()
                                .find(|model| model.id == response.default_fast_model)
                                .cloned();
                            this.recommended_models = response
                                .recommended_models
                                .iter()
                                .filter_map(|id| models.iter().find(|model| &model.id == id))
                                .cloned()
                                .collect();
                            this.models = models;
                            cx.notify();
                        })
                    })??;

                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
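
    /// Fetches the list of available models from the `/models` endpoint,
    /// authenticating with a freshly acquired LLM API token.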
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<zed_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
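
/// Renders the Terms of Service prompt, or returns `None` if the user has
/// already accepted. The layout adapts to where the prompt is shown
/// (`view_kind`).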
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<zed_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}
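
/// The parts of a successful `/completions` response that callers need: the
/// raw streaming body, plus usage and limit information read from the
/// response headers.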
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
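    /// Sends a completion request to the `/completions` endpoint. If the
    /// server reports that the LLM token has expired, the token is refreshed
    /// and the request is retried once; other failures are mapped to typed
    /// errors where possible.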
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                // When the server streams status messages, usage is reported
                // in-band rather than through response headers.
                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            zed_llm_client::Plan::ZedFree => Plan::Free,
                            zed_llm_client::Plan::ZedPro => Plan::ZedPro,
                            zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
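
/// A non-success HTTP response from the cloud LLM service, retaining status,
/// body, and headers for error reporting.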
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use zed_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic
            | zed_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            zed_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            zed_llm_client::LanguageModelProvider::OpenAi
            | zed_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => count_anthropic_tokens(request, cx),
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            zed_llm_client::LanguageModelProvider::Google => {
                // Token counts for Google models are fetched from the cloud
                // `/count_tokens` endpoint rather than computed locally.
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        match self.model.provider {
            zed_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            zed_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
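
/// One event in the completion stream: either a request-status update from
/// the Zed server, or a provider event of type `T`.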
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
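
/// Flattens a stream of `CloudCompletionEvent`s into completion events:
/// status updates pass through unchanged, while provider events are expanded
/// by `map_callback` (a provider's event mapper may yield several events per
/// input).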
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CloudCompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}
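
/// Yields a single `UsageUpdated` status event when usage information was
/// present in the response headers, and nothing otherwise.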
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}
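
/// Yields a single `ToolUseLimitReached` status event when the corresponding
/// response header was set, and nothing otherwise.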
fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CloudCompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}
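
/// Parses the streaming response body as newline-delimited JSON. When the
/// server supports status messages, each line is a full `CloudCompletionEvent`;
/// otherwise each line is a bare provider event wrapped in
/// `CloudCompletionEvent::Event`.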
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
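
/// The provider's configuration UI, which prompts the user to sign in,
/// accept the Terms of Service, and manage their Zed plan.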
struct ConfigurationView {
    state: gpui::Entity<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut Context<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_PRICING_URL: &str = "https://zed.dev/pricing";

        let is_connected = !self.state.read(cx).is_signed_out();
        let user_store = self.state.read(cx).user_store.read(cx);
        let plan = user_store.current_plan();
        let subscription_period = user_store.subscription_period();
        let eligible_for_trial = user_store.trial_started_at().is_none();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = match (plan, subscription_period) {
            (Some(proto::Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro subscription."
            }
            (Some(proto::Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted LLMs through your Zed Pro trial."
            }
            (Some(proto::Plan::Free), Some(_)) => {
                "You have basic access to Zed's hosted LLMs through your Zed Free subscription."
            }
            _ => {
                if eligible_for_trial {
                    "Subscribe for access to Zed's hosted LLMs. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted LLMs."
                }
            }
        };
        let manage_subscription_buttons = if is_pro {
            h_flex().child(
                Button::new("manage_settings", "Manage Subscription")
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .on_click(cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx)))),
            )
        } else {
            h_flex()
                .gap_2()
                .child(
                    Button::new("learn_more", "Learn more")
                        .style(ButtonStyle::Subtle)
                        .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_PRICING_URL))),
                )
                .child(
                    Button::new("upgrade", "Upgrade")
                        .style(ButtonStyle::Subtle)
                        .color(Color::Accent)
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                )
        };

        if is_connected {
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .child(manage_subscription_buttons)
                })
        } else {
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}