use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, OUTDATED_LLM_TOKEN_HEADER_NAME,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, AsyncReadExt as _, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    io::BufReader,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AppContext, AsyncApp, Context, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{
    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
};
use language_model::{
    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};

use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;

use anthropic::completion::{AnthropicEventMapper, into_anthropic};
use google_ai::completion::{GoogleEventMapper, into_google};
use open_ai::completion::{
    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
};

const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;

/// Trait for acquiring and refreshing LLM authentication tokens.
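///
/// A minimal sketch of an implementation (illustrative only; `StaticTokenProvider`
/// and its fields are hypothetical and not part of this crate):
///
/// ```ignore
/// struct StaticTokenProvider {
///     token: String,
/// }
///
/// impl CloudLlmTokenProvider for StaticTokenProvider {
///     // No per-request auth state is needed for a static token.
///     type AuthContext = ();
///
///     fn auth_context(&self, _cx: &impl AppContext) -> Self::AuthContext {}
///
///     fn acquire_token(&self, _auth_context: ()) -> BoxFuture<'static, Result<String>> {
///         let token = self.token.clone();
///         async move { Ok(token) }.boxed()
///     }
///
///     fn refresh_token(&self, auth_context: ()) -> BoxFuture<'static, Result<String>> {
///         // A static token cannot be refreshed; hand back the same one.
///         self.acquire_token(auth_context)
///     }
/// }
/// ```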
pub trait CloudLlmTokenProvider: Send + Sync {
    type AuthContext: Clone + Send + 'static;

    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    pub model: Arc<cloud_llm_client::LanguageModel>,
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    pub app_version: Option<Version>,
    pub request_limiter: RateLimiter,
}

pub struct PerformLlmCompletionResponse {
    pub response: Response<AsyncBody>,
    pub includes_status_messages: bool,
}

impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
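    /// Sends a completion request to the cloud `/completions` endpoint.
    ///
    /// If the server reports an expired or outdated LLM token, the token is
    /// refreshed once and the request is retried before giving up.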
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

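/// Returns whether the response carries a header indicating that the LLM token
/// has expired or is outdated and should be refreshed before retrying.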
fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
    response
        .headers()
        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
        .is_some()
        || response
            .headers()
            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
            .is_some()
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

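/// Deserializes an optional numeric HTTP status code, mapping values that are
/// not valid status codes to `None` instead of failing deserialization.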
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The code itself embeds a status (e.g. "upstream_http_429"); parse it,
                    // falling back to the response status if it isn't a valid number.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive {
                        display: Some(anthropic::AdaptiveThinkingDisplay::Summarized),
                    });
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                if !self.model.supports_fast_mode {
                    request.speed = None;
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                    false,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}

impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

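    /// Fetches the list of available models from the cloud and updates the
    /// provider's cached model state, notifying observers on completion.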
    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

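    /// Issues an authenticated GET request to the `/models` endpoint and
    /// deserializes the response body into a [`ListModelsResponse`].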
    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

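    /// Replaces the cached model list and re-resolves the default, default
    /// fast, and recommended models against the new list.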
    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

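    /// Wraps a cloud model descriptor in a [`CloudLanguageModel`] that shares
    /// this provider's token provider, HTTP client, and app version.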
    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}

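/// Adapts a stream of raw cloud [`CompletionEvent`]s into language model
/// completion events.
///
/// Status events are translated directly, provider-specific events are passed
/// through `map_callback`, and an error is emitted if the stream ends without
/// a `StreamEnded` status from the server.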
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

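/// Maps a cloud provider identifier to its user-facing provider name.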
pub fn provider_name(
    provider: &cloud_llm_client::LanguageModelProvider,
) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
    }
}

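/// Parses a newline-delimited response body into a stream of completion
/// events, deserializing each line either as a [`CompletionEvent`] envelope or
/// as a bare provider event, depending on whether the server sends status
/// messages.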
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with a 503 upstream status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 500 upstream status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 429 upstream status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A regular 500 without an upstream_http_error body should become ApiInternalServerError attributed to Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should also be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in the error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}