use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME,
    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, PlanV1, PlanV2,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}
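
/// Selects between a model's default behavior and its extended "thinking" mode.
///
/// In settings this is tagged by `type`; a thinking entry might look like
/// `{ "type": "thinking", "budget_tokens": 4096 }` (illustrative values).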
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
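            // e.g. a hypothetical "claude-sonnet-4" entry would also produce a
            // "claude-sonnet-4-thinking" variant whose display name gains a " Thinking" suffix.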
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}
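
/// A successful response from the completions endpoint, together with metadata
/// parsed from its headers (usage, tool-use limit, and status-message support).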
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
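    /// Sends a completion request to the Zed LLM service, retrying once with a
    /// refreshed token if the server reports that the current LLM token has expired.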
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
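
/// A non-success HTTP response from the Zed LLM service, kept intact so it can be
/// converted into a `LanguageModelCompletionError`.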
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
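///
/// The upstream status can also be embedded in the `code` itself, as in this
/// rate-limited shape exercised by the tests below:
/// ```json
/// {
///   "code": "upstream_http_429",
///   "message": "Upstream Anthropic rate limit exceeded.",
///   "retry_after": 30.5
/// }
/// ```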
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}
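
/// Deserializes an optional numeric status code, mapping a missing field or a value
/// that is not a valid HTTP status code to `None` instead of failing deserialization.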
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If the code itself embeds a status (e.g. "upstream_http_429"), use that;
                    // otherwise fall back to the status of the response we received.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let use_responses_api = cx
            .update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>())
            .unwrap_or(false);
        let thinking_allowed = request.thinking_allowed;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();

                if use_responses_api {
                    let request = into_open_ai_response(
                        request,
                        &self.model.id.0,
                        self.model.supports_parallel_tool_calls,
                        true,
                        None,
                        None,
                    );
                    let future = self.request_limiter.stream(async move {
                        let PerformLlmCompletionResponse {
                            response,
                            usage,
                            includes_status_messages,
                            tool_use_limit_reached,
                        } = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            app_version,
                            CompletionBody {
                                thread_id,
                                prompt_id,
                                intent,
                                mode,
                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: serde_json::to_value(&request)
                                    .map_err(|e| anyhow!(e))?,
                            },
                        )
                        .await?;

                        let mut mapper = OpenAiResponseEventMapper::new();
                        Ok(map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_status_messages)
                                    .chain(usage_updated_event(usage))
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            &provider_name,
                            move |event| mapper.map_event(event),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                } else {
                    let request = into_open_ai(
                        request,
                        &self.model.id.0,
                        self.model.supports_parallel_tool_calls,
                        true,
                        None,
                        None,
                    );
                    let future = self.request_limiter.stream(async move {
                        let PerformLlmCompletionResponse {
                            response,
                            usage,
                            includes_status_messages,
                            tool_use_limit_reached,
                        } = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            app_version,
                            CompletionBody {
                                thread_id,
                                prompt_id,
                                intent,
                                mode,
                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: serde_json::to_value(&request)
                                    .map_err(|e| anyhow!(e))?,
                            },
                        )
                        .await?;

                        let mut mapper = OpenAiEventMapper::new();
                        Ok(map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_status_messages)
                                    .chain(usage_updated_event(usage))
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            &provider_name,
                            move |event| mapper.map_event(event),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                }
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
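
/// Adapts a cloud completion stream into language-model events: cloud status messages
/// are converted directly, while provider-specific events are handed to `map_callback`
/// (for example an `AnthropicEventMapper` or `OpenAiEventMapper`).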
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![
                        LanguageModelCompletionEvent::from_completion_request_status(
                            event,
                            provider.clone(),
                        ),
                    ]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}
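
/// Reads the streamed response body as newline-delimited JSON, yielding one event per
/// line. When the server advertised status-message support, each line is parsed as a
/// full `CompletionEvent`; otherwise it is a bare provider event wrapped in
/// `CompletionEvent::Event`.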
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with a 503 upstream status should surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 500 upstream status should also surface as UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 429 upstream status should surface as UpstreamProviderError as well
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}