use ai_onboarding::YoungAccountBanner;
use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{Client, CloudUserStore, ModelRequestUsage, UserStore, zed_urls};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
    CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE, Plan,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
    TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use google_ai::GoogleModelMode;
use gpui::{
    AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};

use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};

const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u64>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u64>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
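
// A single entry in `available_models`, as it might appear in the user's
// settings, is sketched below (illustrative values; `provider`, `name`, and
// `max_tokens` are the only fields without defaults):
//
//   {
//     "provider": "anthropic",
//     "name": "claude-3-5-sonnet-20240620",
//     "display_name": "Claude 3.5 Sonnet Thinking",
//     "max_tokens": 200000,
//     "mode": { "type": "thinking", "budget_tokens": 4096 }
//   }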

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    _maintain_client_status: Task<()>,
}

pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    cloud_user_store: Entity<CloudUserStore>,
    status: client::Status,
    accept_terms_of_service_task: Option<Task<Result<()>>>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _fetch_models_task: Task<()>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}

impl State {
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        cloud_user_store: Entity<CloudUserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            cloud_user_store,
            status,
            accept_terms_of_service_task: None,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
            _fetch_models_task: cx.spawn(async move |this, cx| {
                maybe!(async move {
                    let (client, cloud_user_store, llm_api_token) =
                        this.read_with(cx, |this, _cx| {
                            (
                                client.clone(),
                                this.cloud_user_store.clone(),
                                this.llm_api_token.clone(),
                            )
                        })?;

                    loop {
                        let is_authenticated =
                            cloud_user_store.read_with(cx, |this, _cx| this.is_authenticated())?;
                        if is_authenticated {
                            break;
                        }

                        cx.background_executor()
                            .timer(Duration::from_millis(100))
                            .await;
                    }

                    let response = Self::fetch_models(client, llm_api_token).await?;
                    this.update(cx, |this, cx| {
                        this.update_models(response, cx);
                    })
                })
                .await
                .context("failed to fetch Zed models")
                .log_err();
            }),
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                move |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |this, cx| {
                        llm_api_token.refresh(&client).await?;
                        let response = Self::fetch_models(client, llm_api_token).await?;
                        this.update(cx, |this, cx| {
                            this.update_models(response, cx);
                        })
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    fn is_signed_out(&self, cx: &App) -> bool {
        !self.cloud_user_store.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |state, cx| {
            client
                .authenticate_and_connect(true, &cx)
                .await
                .into_response()?;
            state.update(cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.cloud_user_store.read(cx).has_accepted_tos()
    }

    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms_of_service_task = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms_of_service_task = None;
                cx.notify()
            })
        }));
    }

    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
        let mut models = Vec::new();

        for model in response.models {
            models.push(Arc::new(model.clone()));

            // Right now we represent thinking variants of models as separate models on the client,
            // so we need to insert variants for any model that supports thinking.
            if model.supports_thinking {
                models.push(Arc::new(cloud_llm_client::LanguageModel {
                    id: cloud_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
                    display_name: format!("{} Thinking", model.display_name),
                    ..model
                }));
            }
        }

        self.default_model = models
            .iter()
            .find(|model| model.id == response.default_model)
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| model.id == response.default_fast_model)
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
        cx.notify();
    }

    async fn fetch_models(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
    ) -> Result<ListModelsResponse> {
        let http_client = &client.http_client();
        let token = llm_api_token.acquire(&client).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            return Ok(serde_json::from_str(&body)?);
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }
}

impl CloudLanguageModelProvider {
    pub fn new(
        user_store: Entity<UserStore>,
        cloud_user_store: Entity<CloudUserStore>,
        client: Arc<Client>,
        cx: &mut App,
    ) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new(|cx| {
            State::new(
                client.clone(),
                user_store.clone(),
                cloud_user_store.clone(),
                status,
                cx,
            )
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(async move |cx| {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state: state.clone(),
            _maintain_client_status: maintain_client_status,
        }
    }

    fn create_language_model(
        &self,
        model: Arc<cloud_llm_client::LanguageModel>,
        llm_api_token: LlmApiToken,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel {
            id: LanguageModelId(SharedString::from(model.id.0.clone())),
            model,
            llm_api_token: llm_api_token.clone(),
            client: self.client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_model = self.state.read(cx).default_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_model, llm_api_token))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let default_fast_model = self.state.read(cx).default_fast_model.clone()?;
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        Some(self.create_language_model(default_fast_model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .recommended_models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        self.state
            .read(cx)
            .models
            .iter()
            .cloned()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        let state = self.state.read(cx);
        !state.is_signed_out(cx) && state.has_accepted_terms_of_service(cx)
    }

    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView::new(self.state.clone()))
            .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        let state = self.state.read(cx);
        if state.has_accepted_terms_of_service(cx) {
            return None;
        }
        Some(
            render_accept_terms(view, state.accept_terms_of_service_task.is_some(), {
                let state = self.state.clone();
                move |_window, cx| {
                    state.update(cx, |state, cx| state.accept_terms_of_service(cx));
                }
            })
            .into_any_element(),
        )
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

fn render_accept_terms(
    view_kind: LanguageModelProviderTosView,
    accept_terms_of_service_in_progress: bool,
    accept_terms_callback: impl Fn(&mut Window, &mut App) + 'static,
) -> impl IntoElement {
    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_of_service_in_progress)
            .on_click(move |_, window, cx| (accept_terms_callback)(window, cx)),
    );

    if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::TextThreadPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadEmptyState => div().w_0(),
                }
            })
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: Arc<cloud_llm_client::LanguageModel>,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    usage: Option<ModelRequestUsage>,
    tool_use_limit_reached: bool,
    includes_status_messages: bool,
}

impl CloudLanguageModel {
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        app_version: Option<SemanticVersion>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut refreshed_token = false;

        loop {
            let request_builder = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref());
            let request_builder = if let Some(app_version) = app_version {
                request_builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
            } else {
                request_builder
            };

            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;
            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                let tool_use_limit_reached = response
                    .headers()
                    .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
                    .is_some();

                let usage = if includes_status_messages {
                    None
                } else {
                    ModelRequestUsage::from_headers(response.headers()).ok()
                };

                return Ok(PerformLlmCompletionResponse {
                    response,
                    usage,
                    includes_status_messages,
                    tool_use_limit_reached,
                });
            }

            if !refreshed_token
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                token = llm_api_token.refresh(&client).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::FORBIDDEN
                && response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .is_some()
            {
                if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
                    .headers()
                    .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
                    .and_then(|resource| resource.to_str().ok())
                {
                    if let Some(plan) = response
                        .headers()
                        .get(CURRENT_PLAN_HEADER_NAME)
                        .and_then(|plan| plan.to_str().ok())
                        .and_then(|plan| cloud_llm_client::Plan::from_str(plan).ok())
                    {
                        let plan = match plan {
                            cloud_llm_client::Plan::ZedFree => proto::Plan::Free,
                            cloud_llm_client::Plan::ZedPro => proto::Plan::ZedPro,
                            cloud_llm_client::Plan::ZedProTrial => proto::Plan::ZedProTrial,
                        };
                        return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                    }
                }
            } else if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///     "code": "upstream_http_error",
///     "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///     "upstream_status": 503
/// }
/// ```
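///
/// Rate-limit errors may instead encode the upstream status in the `code`
/// itself and include a retry hint, as exercised in the tests below:
/// ```json
/// {
///     "code": "upstream_http_429",
///     "message": "Upstream Anthropic rate limit exceeded.",
///     "retry_after": 30.5
/// }
/// ```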
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

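/// Deserializes an optional numeric HTTP status code, mapping values outside
/// the valid range (e.g. `1000`) to `None` rather than failing
/// deserialization of the whole error body.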
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The status code is embedded in the code string itself
                    // (e.g. "upstream_http_429"); parse it out, falling back to
                    // the HTTP response status if the suffix isn't a valid
                    // status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
        }
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_burn_mode(&self) -> bool {
        self.model.supports_max_mode
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        self.model
            .max_token_count_in_max_mode
            .filter(|_| self.model.supports_max_mode)
            .map(|max_token_count| max_token_count as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                count_anthropic_tokens(request, cx)
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let mode = request.mode;
        let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
        let thinking_allowed = request.thinking_allowed;
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if thinking_allowed && self.model.id.0.ends_with("-thinking") {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err).into()) }.boxed(),
                };
                let request = into_open_ai(
                    request,
                    model.id(),
                    model.supports_parallel_tool_calls(),
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            mode,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    stream
        .flat_map(move |event| {
            futures::stream::iter(match event {
                Err(error) => {
                    vec![Err(LanguageModelCompletionError::from(error))]
                }
                Ok(CompletionEvent::Status(event)) => {
                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
                }
                Ok(CompletionEvent::Event(event)) => map_callback(event),
            })
        })
        .boxed()
}

fn usage_updated_event<T>(
    usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(usage.map(|usage| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::UsageUpdated {
                amount: usage.amount as usize,
                limit: usage.limit,
            },
        ))
    }))
}

fn tool_use_limit_reached_event<T>(
    tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::iter(tool_use_limit_reached.then(|| {
        Ok(CompletionEvent::Status(
            CompletionRequestStatus::ToolUseLimitReached,
        ))
    }))
}

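/// Decodes the streaming completion response body, which is newline-delimited
/// JSON: when the server advertised status-message support, each line is a
/// full `CompletionEvent<T>`; otherwise each line is a bare provider event `T`
/// and gets wrapped in `CompletionEvent::Event`.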
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    is_connected: bool,
    plan: Option<Plan>,
    subscription_period: Option<(DateTime<Utc>, DateTime<Utc>)>,
    eligible_for_trial: bool,
    has_accepted_terms_of_service: bool,
    account_too_young: bool,
    accept_terms_of_service_in_progress: bool,
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl RenderOnce for ZedAiConfiguration {
    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
        let young_account_banner = YoungAccountBanner;

        let is_pro = self.plan == Some(Plan::ZedPro);
        let subscription_text = match (self.plan, self.subscription_period) {
            (Some(Plan::ZedPro), Some(_)) => {
                "You have access to Zed's hosted models through your Pro subscription."
            }
            (Some(Plan::ZedProTrial), Some(_)) => {
                "You have access to Zed's hosted models through your Pro trial."
            }
            (Some(Plan::ZedFree), Some(_)) => {
                "You have basic access to Zed's hosted models through the Free plan."
            }
            _ => {
                if self.eligible_for_trial {
                    "Subscribe for access to Zed's hosted models. Start with a 14-day free trial."
                } else {
                    "Subscribe for access to Zed's hosted models."
                }
            }
        };

        let manage_subscription_button = if is_pro {
            Button::new("manage_settings", "Manage Subscription")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
                .into_any_element()
        } else if self.plan.is_none() || self.eligible_for_trial {
            Button::new("start_trial", "Start 14-day Free Pro Trial")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
                .into_any_element()
        } else {
            Button::new("upgrade", "Upgrade to Pro")
                .full_width()
                .style(ButtonStyle::Tinted(TintColor::Accent))
                .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
                .into_any_element()
        };

        if !self.is_connected {
            return v_flex()
                .gap_2()
                .child(Label::new(
                    "Sign in to have access to Zed's complete agentic experience with hosted models.",
                ))
                .child(
                    Button::new("sign_in", "Sign In to use Zed AI")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .full_width()
                        .on_click({
                            let callback = self.sign_in_callback.clone();
                            move |_, window, cx| (callback)(window, cx)
                        }),
                );
        }

        v_flex()
            .gap_2()
            .w_full()
            .when(!self.has_accepted_terms_of_service, |this| {
                this.child(render_accept_terms(
                    LanguageModelProviderTosView::Configuration,
                    self.accept_terms_of_service_in_progress,
                    {
                        let callback = self.accept_terms_of_service_callback.clone();
                        move |window, cx| (callback)(window, cx)
                    },
                ))
            })
            .map(|this| {
                if self.has_accepted_terms_of_service && self.account_too_young {
                    this.child(young_account_banner).child(
                        Button::new("upgrade", "Upgrade to Pro")
                            .style(ButtonStyle::Tinted(TintColor::Accent))
                            .full_width()
                            .on_click(|_, _, cx| {
                                cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))
                            }),
                    )
                } else if self.has_accepted_terms_of_service {
                    this.text_sm()
                        .child(subscription_text)
                        .child(manage_subscription_button)
                } else {
                    this
                }
            })
    }
}

struct ConfigurationView {
    state: Entity<State>,
    accept_terms_of_service_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}

impl ConfigurationView {
    fn new(state: Entity<State>) -> Self {
        let accept_terms_of_service_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.accept_terms_of_service(cx);
                });
            }
        });

        let sign_in_callback = Arc::new({
            let state = state.clone();
            move |_window: &mut Window, cx: &mut App| {
                state.update(cx, |state, cx| {
                    state.authenticate(cx).detach_and_log_err(cx);
                });
            }
        });

        Self {
            state,
            accept_terms_of_service_callback,
            sign_in_callback,
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let cloud_user_store = state.cloud_user_store.read(cx);

        ZedAiConfiguration {
            is_connected: !state.is_signed_out(cx),
            plan: cloud_user_store.plan(),
            subscription_period: cloud_user_store.subscription_period(),
            eligible_for_trial: cloud_user_store.trial_started_at().is_none(),
            has_accepted_terms_of_service: state.has_accepted_terms_of_service(cx),
            account_too_young: cloud_user_store.account_too_young(),
            accept_terms_of_service_in_progress: state.accept_terms_of_service_task.is_some(),
            accept_terms_of_service_callback: self.accept_terms_of_service_callback.clone(),
            sign_in_callback: self.sign_in_callback.clone(),
        }
    }
}

impl Component for ZedAiConfiguration {
    fn scope() -> ComponentScope {
        ComponentScope::Agent
    }

    fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
        fn configuration(
            is_connected: bool,
            plan: Option<Plan>,
            eligible_for_trial: bool,
            account_too_young: bool,
            has_accepted_terms_of_service: bool,
        ) -> AnyElement {
            ZedAiConfiguration {
                is_connected,
                plan,
                subscription_period: plan
                    .is_some()
                    .then(|| (Utc::now(), Utc::now() + chrono::Duration::days(7))),
                eligible_for_trial,
                has_accepted_terms_of_service,
                account_too_young,
                accept_terms_of_service_in_progress: false,
                accept_terms_of_service_callback: Arc::new(|_, _| {}),
                sign_in_callback: Arc::new(|_, _| {}),
            }
            .into_any_element()
        }

        Some(
            v_flex()
                .p_4()
                .gap_4()
                .children(vec![
                    single_example(
                        "Not connected",
                        configuration(false, None, false, false, true),
                    ),
                    single_example(
                        "Accept Terms of Service",
                        configuration(true, None, true, false, false),
                    ),
                    single_example(
                        "No Plan - Not eligible for trial",
                        configuration(true, None, false, false, true),
                    ),
                    single_example(
                        "No Plan - Eligible for trial",
                        configuration(true, None, true, false, true),
                    ),
                    single_example(
                        "Free Plan",
                        configuration(true, Some(Plan::ZedFree), true, false, true),
                    ),
                    single_example(
                        "Zed Pro Trial Plan",
                        configuration(true, Some(Plan::ZedProTrial), true, false, true),
                    ),
                    single_example(
                        "Zed Pro Plan",
                        configuration(true, Some(Plan::ZedPro), true, false, true),
                    ),
                ])
                .into_any_element(),
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // An upstream_http_error with a 503 upstream status should surface as
        // UpstreamProviderError, preserving the upstream message.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error with a 500 upstream status should also surface
        // as UpstreamProviderError.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error with a 429 upstream status should surface as
        // UpstreamProviderError as well.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A regular 500 error without an upstream_http_error body should remain
        // an ApiInternalServerError attributed to Zed.
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should be converted to
        // UpstreamProviderError, carrying the parsed status and retry_after.
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to regular error handling.
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
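
    // A minimal sketch covering `deserialize_optional_status_code` directly:
    // in-range upstream statuses deserialize to `Some(StatusCode)`, while
    // out-of-range values are dropped to `None` instead of failing.
    #[test]
    fn test_deserialize_optional_upstream_status() {
        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"m","upstream_status":503}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, Some(StatusCode::SERVICE_UNAVAILABLE));

        let error: CloudApiError = serde_json::from_str(
            r#"{"code":"upstream_http_error","message":"m","upstream_status":1000}"#,
        )
        .unwrap();
        assert_eq!(error.upstream_status, None);
    }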
}