use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
    PlanV1, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME, TOOL_USE_LIMIT_REACHED_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use feature_flags::{BillingV2FeatureFlag, FeatureFlagAppExt};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

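/// A custom model entry that can be supplied through settings for the Zed hosted
/// provider.
///
/// Illustrative JSON for one entry (a sketch only; the values are placeholders and
/// the field names mirror this struct's serde representation):
///
/// ```json
/// {
///     "provider": "anthropic",
///     "name": "claude-3-5-sonnet-20240620",
///     "display_name": "Claude 3.5 Sonnet",
///     "max_tokens": 200000,
///     "default_temperature": 0.7,
///     "mode": { "type": "thinking", "budget_tokens": 4096 }
/// }
/// ```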
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API, e.g. `claude-3-5-sonnet-20240620`.
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only).
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking).
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

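/// Provider state shared by all cloud models: the authenticated [`Client`], the LLM
/// API token, the signed-in user's store, and the model list fetched from the cloud.
/// The model list is fetched once a user is signed in and refetched whenever the LLM
/// token is refreshed.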
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

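    /// Fetches the list of available models from the cloud `/models` endpoint,
    /// authenticating with the current LLM API token.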
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

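/// A single hosted model exposed through Zed's cloud provider. Completion and
/// token-count requests are proxied through the cloud LLM endpoints using the shared
/// [`LlmApiToken`], and are throttled by this model's [`RateLimiter`].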
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
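    /// Sends a completion request to the cloud `/completions` endpoint.
    ///
    /// If the server reports an expired LLM token, the token is refreshed once and the
    /// request is retried. On success, the response headers are inspected for request
    /// usage, the tool-use-limit flag, and whether the server will stream status
    /// messages alongside completion events.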
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///     "code": "upstream_http_error",
///     "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///     "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

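/// Deserializes an optional numeric HTTP status code, mapping values that are not
/// valid status codes to `None` instead of failing deserialization.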
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If there's a status code embedded in the code string (e.g.
                    // "upstream_http_429"), use that; otherwise fall back to the
                    // status of the response we received from the cloud API.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    model.supports_prompt_cache_key(),
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

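/// Flattens a stream of cloud completion events into language-model completion events:
/// status events pass through as `StatusUpdate`s, provider events are mapped by
/// `map_callback`, and transport errors are converted into completion errors.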
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

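/// Emits a single `UsageUpdated` status event when request usage was reported via the
/// response headers; otherwise emits nothing.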
fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

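/// Parses the streaming response body as newline-delimited JSON. When the server
/// indicated support for status messages, each line is a full `CompletionEvent`;
/// otherwise each line is a bare provider event and is wrapped accordingly.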
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

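/// The content of the Zed AI configuration card. Shows a sign-in button when
/// disconnected; otherwise it summarizes the current plan and offers the matching
/// subscription action, or a young-account banner when applicable.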
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let is_free_v2 = self
            .plan
            .is_some_and(|plan| plan == Plan::V2(PlanV2::ZedFree));
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner::new(
                    is_free_v2 || cx.has_flag::<BillingV2FeatureFlag>(),
                ))
                .child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

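/// Entity wrapper around [`ZedAiConfiguration`] that wires the sign-in button to
/// [`State::authenticate`].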
struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with a 503 upstream status should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 500 upstream status should also become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 429 upstream status should also become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}