use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, OUTDATED_LLM_TOKEN_HEADER_NAME,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AppContext, AsyncApp, Context, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{
    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
};
use language_model::{
    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};

use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;

use anthropic::completion::{AnthropicEventMapper, into_anthropic};
use google_ai::completion::{GoogleEventMapper, into_google};
use open_ai::completion::{
    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
};

const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;

/// Trait for acquiring and refreshing LLM authentication tokens.
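///
/// A minimal implementation sketch (illustrative only; the `StaticTokenProvider`
/// type below is hypothetical and assumes a fixed token that never needs refreshing):
///
/// ```ignore
/// use std::sync::Arc;
///
/// use anyhow::{Result, anyhow};
/// use futures::{FutureExt as _, future::BoxFuture};
/// use gpui::AppContext;
///
/// struct StaticTokenProvider {
///     token: Arc<str>,
/// }
///
/// impl CloudLlmTokenProvider for StaticTokenProvider {
///     // No per-request state is needed, so the auth context is a unit value.
///     type AuthContext = ();
///
///     fn auth_context(&self, _cx: &impl AppContext) -> Self::AuthContext {}
///
///     fn acquire_token(&self, _auth_context: ()) -> BoxFuture<'static, Result<String>> {
///         let token = self.token.to_string();
///         async move { Ok(token) }.boxed()
///     }
///
///     fn refresh_token(&self, _auth_context: ()) -> BoxFuture<'static, Result<String>> {
///         // A static token cannot be refreshed; report that as an error so the
///         // caller surfaces the failure instead of retrying.
///         async move { Err(anyhow!("static tokens cannot be refreshed")) }.boxed()
///     }
/// }
/// ```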
pub trait CloudLlmTokenProvider: Send + Sync {
    type AuthContext: Clone + Send + 'static;

    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}

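/// Whether a completion request runs the model in its default mode or with
/// extended thinking enabled.
///
/// With the serde attributes below, the value is tagged by `type` on the wire,
/// for example:
///
/// ```json
/// { "type": "thinking", "budget_tokens": 4096 }
/// ```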
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    pub model: Arc<cloud_llm_client::LanguageModel>,
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    pub app_version: Option<Version>,
    pub request_limiter: RateLimiter,
}

pub struct PerformLlmCompletionResponse {
    pub response: Response<AsyncBody>,
    pub includes_status_messages: bool,
}

impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
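    /// Sends a completion request to the Zed LLM `/completions` endpoint.
    ///
    /// Acquires an LLM token from the [`CloudLlmTokenProvider`] first. If the server
    /// replies with an expired- or outdated-token header, the token is refreshed and
    /// the request is retried once before the error is returned to the caller.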
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
    response
        .headers()
        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
        .is_some()
        || response
            .headers()
            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
            .is_some()
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

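// Classifies an `ApiError` from the cloud API:
// - bodies whose `code` is `upstream_http_error` (optionally carrying an
//   `upstream_status`) or `upstream_http_<status>` become `UpstreamProviderError`s
//   attributed to the upstream provider;
// - everything else is mapped from the response's own HTTP status and attributed
//   to the Zed provider.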
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The code may embed a status itself (e.g. "upstream_http_429");
                    // use that if it parses, otherwise fall back to the response's status.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                if !self.model.supports_fast_mode {
                    request.speed = None;
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                    false,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

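/// Holds the catalog of models advertised by the Zed LLM endpoint and constructs
/// [`CloudLanguageModel`]s that share one token provider and HTTP client.
///
/// A rough usage sketch (assumes a `ListModelsResponse` obtained from the `/models`
/// endpoint; in app code, [`CloudModelProvider::refresh_models`] performs the fetch
/// and update in one step):
///
/// ```ignore
/// let mut provider = CloudModelProvider::new(token_provider, http_client, app_version);
/// provider.update_models(response);
/// if let Some(default) = provider.default_model() {
///     let model: Arc<dyn LanguageModel> = provider.create_model(default);
/// }
/// ```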
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}

impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}

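/// Adapts the raw [`CompletionEvent`] stream coming from the cloud endpoint into
/// [`LanguageModelCompletionEvent`]s.
///
/// Status events are translated directly, while provider-specific events are passed
/// through `map_callback`. If the stream finishes without a
/// [`CompletionRequestStatus::StreamEnded`] status, a
/// [`LanguageModelCompletionError::StreamEndedUnexpectedly`] error is emitted as the
/// final item.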
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

pub fn provider_name(
    provider: &cloud_llm_client::LanguageModelProvider,
) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
    }
}

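/// Parses the response body as newline-delimited JSON.
///
/// When the server advertised status-message support, each line is decoded as a
/// full [`CompletionEvent`]; otherwise each line is decoded as the provider event
/// type `T` and wrapped in [`CompletionEvent::Event`].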
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // An upstream_http_error body with upstream_status 503 should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error body with upstream_status 500 should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error body with upstream_status 429 should become an UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain 500 without an upstream_http_error body should be attributed to Zed as ApiInternalServerError
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}