use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME,
    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan, PlanV1, PlanV2,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use feature_flags::{FeatureFlagAppExt as _, OpenAiResponsesApiFeatureFlag};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError,
    PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
pub use settings::ZedDotDevAvailableModel as AvailableModel;
pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use crate::provider::x_ai::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}
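
/// Selects how a hosted model should run: the standard completion mode, or an
/// extended "thinking" mode with an optional reasoning-token budget.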
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
        let mut current_user = user_store.read(cx).watch_current_user();
        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, llm_api_token) = this
                        .read_with(cx, |this, _cx| (client.clone(), this.llm_api_token.clone()))?;

                    while current_user.borrow().is_none() {
                        current_user.next().await;
                    }

                    let response =
                        Self::fetch_models(client.clone(), llm_api_token.clone()).await?;
                    this.update(cx, |this, cx| this.update_models(response, cx))?;
                    anyhow::Ok(())
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        self.user_store.read(cx).current_user().is_none()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client.sign_in_with_optional_connect(true, cx).await?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

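    /// Fetches the list of available models from the Zed LLM API
    /// (`GET /models`), authenticating with the current LLM token.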
    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token,
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiZed)
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        _: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
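    /// Sends a completion request to the Zed LLM API (`POST /completions`),
    /// refreshing the LLM token once if the server reports it has expired, and
    /// translating plan-limit and payment-required responses into typed errors.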
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                    && let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::PlanV1::from_str(plan).ok())
                        .map(Plan::V1)
                {
                    return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

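/// Deserializes an optional numeric status code, mapping values that are not
/// valid HTTP status codes to `None` instead of failing deserialization.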
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

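/// Converts cloud API errors into `LanguageModelCompletionError`s. Errors whose
/// `code` starts with `upstream_http_` are surfaced as upstream provider errors,
/// taking the upstream status from the JSON field or the code suffix (e.g.
/// `upstream_http_429`) and forwarding any `retry_after` hint; other errors fall
/// back to a conversion based on the HTTP status of the cloud response.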
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429")
                    // then use that; otherwise, fall back to the status of the cloud response.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let use_responses_api = cx
            .update(|cx| cx.has_flag::<OpenAiResponsesApiFeatureFlag>())
            .unwrap_or(false);
        let thinking_allowed = request.thinking_allowed;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();

                if use_responses_api {
                    let request = into_open_ai_response(
                        request,
                        &self.model.id.0,
                        self.model.supports_parallel_tool_calls,
                        true,
                        None,
                        None,
                    );
                    let future = self.request_limiter.stream(async move {
                        let PerformLlmCompletionResponse {
                            response,
                            usage,
                            includes_status_messages,
                            tool_use_limit_reached,
                        } = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            app_version,
                            CompletionBody {
                                thread_id,
                                prompt_id,
                                intent,
                                mode,
                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: serde_json::to_value(&request)
                                    .map_err(|e| anyhow!(e))?,
                            },
                        )
                        .await?;

                        let mut mapper = OpenAiResponseEventMapper::new();
                        Ok(map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_status_messages)
                                    .chain(usage_updated_event(usage))
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            &provider_name,
                            move |event| mapper.map_event(event),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                } else {
                    let request = into_open_ai(
                        request,
                        &self.model.id.0,
                        self.model.supports_parallel_tool_calls,
                        true,
                        None,
                        None,
                    );
                    let future = self.request_limiter.stream(async move {
                        let PerformLlmCompletionResponse {
                            response,
                            usage,
                            includes_status_messages,
                            tool_use_limit_reached,
                        } = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            app_version,
                            CompletionBody {
                                thread_id,
                                prompt_id,
                                intent,
                                mode,
                                provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: serde_json::to_value(&request)
                                    .map_err(|e| anyhow!(e))?,
                            },
                        )
                        .await?;

                        let mut mapper = OpenAiEventMapper::new();
                        Ok(map_cloud_completion_events(
                            Box::pin(
                                response_lines(response, includes_status_messages)
                                    .chain(usage_updated_event(usage))
                                    .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                            ),
                            &provider_name,
                            move |event| mapper.map_event(event),
                        ))
                    });
                    async move { Ok(future.await?.boxed()) }.boxed()
                }
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

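/// Adapts a cloud completion stream into language-model completion events:
/// status messages from the Zed server are converted directly, while
/// provider-specific events are passed to the per-provider `map_callback`.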
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![
                        LanguageModelCompletionEvent::from_completion_request_status(
                            event,
                            provider.clone(),
                        ),
                    ]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => {
            language_model::ANTHROPIC_PROVIDER_NAME
        }
        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
    }
}

fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

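/// Streams the response body as newline-delimited JSON, yielding one decoded
/// event per line. When the server advertised status-message support, each line
/// is a full `CompletionEvent`; otherwise each line is a bare provider event.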
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    account_too_young: bool,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let is_pro = self.plan.is_some_and(|plan| {
            matches!(plan, Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro))
        });
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::V1(PlanV1::ZedPro) | Plan::V2(PlanV2::ZedPro)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::V1(PlanV1::ZedProTrial) | Plan::V2(PlanV2::ZedProTrial)), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::V1(PlanV1::ZedFree)), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            (Some(Plan::V2(PlanV2::ZedFree)), Some(_)) => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_buttons = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex().gap_2().w_full().map(|this| {
            if self.account_too_young {
                this.child(YoungAccountBanner).child(
                    Button::new("upgrade", "Upgrade to Pro")
                        .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
                        .full_width()
                        .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
                )
            } else {
                this.text_sm()
                    .child(subscription_text)
                    .child(manage_subscription_buttons)
            }
        })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let user_store = state.user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: user_store.plan(),
            subscription_period: user_store.subscription_period(),
            eligible_for_trial: user_store.trial_started_at().is_none(),
            account_too_young: user_store.account_too_young(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn name() -> &'static str {
        "AI Configuration Content"
    }

    fn sort_name() -> &'static str {
        "AI Configuration Content"
    }

    fn scope() -> ComponentScope {
        ComponentScope::Onboarding
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                account_too_young,
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example("Not connected", configuration(false, None, false, false)),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedFree)), true, false),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedProTrial)), true, false),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::V1(PlanV1::ZedPro)), true, false),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with a 503 upstream status should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 500 upstream status should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 429 upstream status should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should also be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}