1use anthropic::AnthropicModelMode;
2use anyhow::{Context as _, Result, anyhow};
3use cloud_llm_client::{
4 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
5 CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
6 CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
7 OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
8 ZED_VERSION_HEADER_NAME,
9};
10use futures::{
11 AsyncBufReadExt, FutureExt, Stream, StreamExt,
12 future::BoxFuture,
13 stream::{self, BoxStream},
14};
15use google_ai::GoogleModelMode;
16use gpui::{App, AppContext, AsyncApp, Context, Task};
17use http_client::http::{HeaderMap, HeaderValue};
18use http_client::{
19 AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
20};
21use language_model::{
22 ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
23 LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
24 LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
25 LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
26 LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
27 OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
28 ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
29};
30
31use schemars::JsonSchema;
32use semver::Version;
33use serde::{Deserialize, Serialize, de::DeserializeOwned};
34use smol::io::{AsyncReadExt, BufReader};
35use std::collections::VecDeque;
36use std::pin::Pin;
37use std::str::FromStr;
38use std::sync::Arc;
39use std::task::Poll;
40use std::time::Duration;
41use thiserror::Error;
42
43use anthropic::completion::{
44 AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
45};
46use google_ai::completion::{GoogleEventMapper, into_google};
47use open_ai::completion::{
48 OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
49 into_open_ai_response,
50};
51use x_ai::completion::count_xai_tokens;
52
/// Provider id under which all cloud-proxied models are registered.
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Human-readable provider name used in errors and telemetry for the cloud provider itself.
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
55
/// Trait for acquiring and refreshing LLM authentication tokens.
pub trait CloudLlmTokenProvider: Send + Sync {
    /// State captured up front (with app access) and handed to the async
    /// token methods, which run without access to the app.
    type AuthContext: Clone + Send + 'static;

    /// Captures whatever context the token methods need; called where an
    /// `AppContext` is available.
    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    /// Returns a bearer token to authenticate LLM requests.
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    /// Obtains a new token; called after the server flags the current token
    /// as expired or outdated (see `needs_llm_token_refresh`).
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
64
/// Reasoning mode requested for a model.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Plain completion with no extended-thinking budget.
    #[default]
    Default,
    /// Extended thinking with an optional reasoning-token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
75
76impl From<ModelMode> for AnthropicModelMode {
77 fn from(value: ModelMode) -> Self {
78 match value {
79 ModelMode::Default => AnthropicModelMode::Default,
80 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
81 }
82 }
83}
84
/// A language model served through Zed's cloud LLM proxy.
pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    /// Identifier for this model instance.
    pub id: LanguageModelId,
    /// Server-provided model metadata (capabilities, limits, upstream provider).
    pub model: Arc<cloud_llm_client::LanguageModel>,
    /// Source of bearer tokens used to authenticate requests.
    pub token_provider: Arc<TP>,
    /// HTTP client that can build Zed LLM endpoint URLs.
    pub http_client: Arc<HttpClientWithUrl>,
    /// App version sent via `ZED_VERSION_HEADER_NAME` when known.
    pub app_version: Option<Version>,
    /// Rate limiter applied to completion requests for this model.
    pub request_limiter: RateLimiter,
}
93
/// Successful result of [`CloudLanguageModel::perform_llm_completion`].
pub struct PerformLlmCompletionResponse {
    /// The streaming HTTP response.
    pub response: Response<AsyncBody>,
    /// Whether the server indicated (via response header) that it will
    /// interleave status messages in the event stream.
    pub includes_status_messages: bool,
}
98
impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
    /// Sends a completion request to the cloud `/completions` endpoint.
    ///
    /// Acquires a bearer token, POSTs `body`, and retries exactly once with a
    /// freshly refreshed token if the server reports the token as expired or
    /// outdated. On success, returns the streaming response together with a
    /// flag indicating whether status messages are interleaved in the stream.
    ///
    /// # Errors
    /// Returns `PaymentRequiredError` on HTTP 402, and an `ApiError`
    /// (status + body + headers) for any other non-success response.
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                // Only send the version header when an app version is known.
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise support for interleaved status messages and the
                // explicit end-of-stream status marker.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                // The server echoes whether it will emit status messages.
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Retry at most once with a refreshed token when the server flags
            // the current token as stale.
            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Capture body and headers so the caller can map this into a
            // structured completion error.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
158
159fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
160 response
161 .headers()
162 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
163 .is_some()
164 || response
165 .headers()
166 .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
167 .is_some()
168}
169
/// Raw non-success HTTP response from the cloud API, before it is mapped into
/// a `LanguageModelCompletionError`.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    // HTTP status of the failed response.
    status: StatusCode,
    // Response body text; may be JSON parseable as `CloudApiError`.
    body: String,
    // Response headers, preserved for downstream inspection.
    headers: HeaderMap<HeaderValue>,
}
177
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    // Machine-readable code, e.g. "upstream_http_error" or "upstream_http_429".
    code: String,
    // Human-readable description of the failure.
    message: String,
    // Status reported by the upstream provider, when present. Invalid codes
    // deserialize to `None` rather than failing the whole parse.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    // Suggested retry delay in (possibly fractional) seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
198
199fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
200where
201 D: serde::Deserializer<'de>,
202{
203 let opt: Option<u16> = Option::deserialize(deserializer)?;
204 Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
205}
206
impl From<ApiError> for LanguageModelCompletionError {
    /// Maps a raw HTTP-level `ApiError` to a structured completion error.
    ///
    /// If the body parses as a `CloudApiError` with an `upstream_http_*` code,
    /// the error is attributed to the upstream provider; otherwise it is
    /// attributed to Zed's own API via `from_http_status`.
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                // Resolve the upstream status from, in order: the explicit
                // `upstream_status` field, the original response status (for
                // the generic "upstream_http_error" code), or a status code
                // embedded in the code string itself.
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429")
                    // then use that; otherwise, see if the JSON contains a status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            // Parsed as a cloud error, but not an upstream one: use the
            // structured message with the original status.
            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        // Body is not a recognizable cloud error payload; fall back to the
        // raw status and body text.
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}
250
251impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
252 fn id(&self) -> LanguageModelId {
253 self.id.clone()
254 }
255
256 fn name(&self) -> LanguageModelName {
257 LanguageModelName::from(self.model.display_name.clone())
258 }
259
260 fn provider_id(&self) -> LanguageModelProviderId {
261 PROVIDER_ID
262 }
263
264 fn provider_name(&self) -> LanguageModelProviderName {
265 PROVIDER_NAME
266 }
267
268 fn upstream_provider_id(&self) -> LanguageModelProviderId {
269 use cloud_llm_client::LanguageModelProvider::*;
270 match self.model.provider {
271 Anthropic => ANTHROPIC_PROVIDER_ID,
272 OpenAi => OPEN_AI_PROVIDER_ID,
273 Google => GOOGLE_PROVIDER_ID,
274 XAi => X_AI_PROVIDER_ID,
275 }
276 }
277
278 fn upstream_provider_name(&self) -> LanguageModelProviderName {
279 use cloud_llm_client::LanguageModelProvider::*;
280 match self.model.provider {
281 Anthropic => ANTHROPIC_PROVIDER_NAME,
282 OpenAi => OPEN_AI_PROVIDER_NAME,
283 Google => GOOGLE_PROVIDER_NAME,
284 XAi => X_AI_PROVIDER_NAME,
285 }
286 }
287
288 fn is_latest(&self) -> bool {
289 self.model.is_latest
290 }
291
292 fn supports_tools(&self) -> bool {
293 self.model.supports_tools
294 }
295
296 fn supports_images(&self) -> bool {
297 self.model.supports_images
298 }
299
300 fn supports_thinking(&self) -> bool {
301 self.model.supports_thinking
302 }
303
304 fn supports_fast_mode(&self) -> bool {
305 self.model.supports_fast_mode
306 }
307
308 fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
309 self.model
310 .supported_effort_levels
311 .iter()
312 .map(|effort_level| LanguageModelEffortLevel {
313 name: effort_level.name.clone().into(),
314 value: effort_level.value.clone().into(),
315 is_default: effort_level.is_default.unwrap_or(false),
316 })
317 .collect()
318 }
319
320 fn supports_streaming_tools(&self) -> bool {
321 self.model.supports_streaming_tools
322 }
323
324 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
325 match choice {
326 LanguageModelToolChoice::Auto
327 | LanguageModelToolChoice::Any
328 | LanguageModelToolChoice::None => true,
329 }
330 }
331
332 fn supports_split_token_display(&self) -> bool {
333 use cloud_llm_client::LanguageModelProvider::*;
334 matches!(self.model.provider, OpenAi | XAi)
335 }
336
337 fn telemetry_id(&self) -> String {
338 format!("zed.dev/{}", self.model.id)
339 }
340
341 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
342 match self.model.provider {
343 cloud_llm_client::LanguageModelProvider::Anthropic
344 | cloud_llm_client::LanguageModelProvider::OpenAi => {
345 LanguageModelToolSchemaFormat::JsonSchema
346 }
347 cloud_llm_client::LanguageModelProvider::Google
348 | cloud_llm_client::LanguageModelProvider::XAi => {
349 LanguageModelToolSchemaFormat::JsonSchemaSubset
350 }
351 }
352 }
353
354 fn max_token_count(&self) -> u64 {
355 self.model.max_token_count as u64
356 }
357
358 fn max_output_tokens(&self) -> Option<u64> {
359 Some(self.model.max_output_tokens as u64)
360 }
361
362 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
363 match &self.model.provider {
364 cloud_llm_client::LanguageModelProvider::Anthropic => {
365 Some(LanguageModelCacheConfiguration {
366 min_total_token: 2_048,
367 should_speculate: true,
368 max_cache_anchors: 4,
369 })
370 }
371 cloud_llm_client::LanguageModelProvider::OpenAi
372 | cloud_llm_client::LanguageModelProvider::XAi
373 | cloud_llm_client::LanguageModelProvider::Google => None,
374 }
375 }
376
377 fn count_tokens(
378 &self,
379 request: LanguageModelRequest,
380 cx: &App,
381 ) -> BoxFuture<'static, Result<u64>> {
382 match self.model.provider {
383 cloud_llm_client::LanguageModelProvider::Anthropic => cx
384 .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
385 .boxed(),
386 cloud_llm_client::LanguageModelProvider::OpenAi => {
387 let model = match open_ai::Model::from_id(&self.model.id.0) {
388 Ok(model) => model,
389 Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
390 };
391 cx.background_spawn(async move { count_open_ai_tokens(request, model) })
392 .boxed()
393 }
394 cloud_llm_client::LanguageModelProvider::XAi => {
395 let model = match x_ai::Model::from_id(&self.model.id.0) {
396 Ok(model) => model,
397 Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
398 };
399 cx.background_spawn(async move { count_xai_tokens(request, model) })
400 .boxed()
401 }
402 cloud_llm_client::LanguageModelProvider::Google => {
403 let http_client = self.http_client.clone();
404 let token_provider = self.token_provider.clone();
405 let model_id = self.model.id.to_string();
406 let generate_content_request =
407 into_google(request, model_id.clone(), GoogleModelMode::Default);
408 let auth_context = token_provider.auth_context(cx);
409 async move {
410 let token = token_provider.acquire_token(auth_context).await?;
411
412 let request_body = CountTokensBody {
413 provider: cloud_llm_client::LanguageModelProvider::Google,
414 model: model_id,
415 provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
416 generate_content_request,
417 })?,
418 };
419 let request = http_client::Request::builder()
420 .method(Method::POST)
421 .uri(
422 http_client
423 .build_zed_llm_url("/count_tokens", &[])?
424 .as_ref(),
425 )
426 .header("Content-Type", "application/json")
427 .header("Authorization", format!("Bearer {token}"))
428 .body(serde_json::to_string(&request_body)?.into())?;
429 let mut response = http_client.send(request).await?;
430 let status = response.status();
431 let headers = response.headers().clone();
432 let mut response_body = String::new();
433 response
434 .body_mut()
435 .read_to_string(&mut response_body)
436 .await?;
437
438 if status.is_success() {
439 let response_body: CountTokensResponse =
440 serde_json::from_str(&response_body)?;
441
442 Ok(response_body.tokens as u64)
443 } else {
444 Err(anyhow!(ApiError {
445 status,
446 body: response_body,
447 headers
448 }))
449 }
450 }
451 .boxed()
452 }
453 }
454 }
455
456 fn stream_completion(
457 &self,
458 request: LanguageModelRequest,
459 cx: &AsyncApp,
460 ) -> BoxFuture<
461 'static,
462 Result<
463 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
464 LanguageModelCompletionError,
465 >,
466 > {
467 let thread_id = request.thread_id.clone();
468 let prompt_id = request.prompt_id.clone();
469 let app_version = self.app_version.clone();
470 let thinking_allowed = request.thinking_allowed;
471 let enable_thinking = thinking_allowed && self.model.supports_thinking;
472 let provider_name = provider_name(&self.model.provider);
473 match self.model.provider {
474 cloud_llm_client::LanguageModelProvider::Anthropic => {
475 let effort = request
476 .thinking_effort
477 .as_ref()
478 .and_then(|effort| anthropic::Effort::from_str(effort).ok());
479
480 let mut request = into_anthropic(
481 request,
482 self.model.id.to_string(),
483 1.0,
484 self.model.max_output_tokens as u64,
485 if enable_thinking {
486 AnthropicModelMode::Thinking {
487 budget_tokens: Some(4_096),
488 }
489 } else {
490 AnthropicModelMode::Default
491 },
492 );
493
494 if enable_thinking && effort.is_some() {
495 request.thinking = Some(anthropic::Thinking::Adaptive);
496 request.output_config = Some(anthropic::OutputConfig { effort });
497 }
498
499 if !self.model.supports_fast_mode {
500 request.speed = None;
501 }
502
503 let http_client = self.http_client.clone();
504 let token_provider = self.token_provider.clone();
505 let auth_context = token_provider.auth_context(cx);
506 let future = self.request_limiter.stream(async move {
507 let PerformLlmCompletionResponse {
508 response,
509 includes_status_messages,
510 } = Self::perform_llm_completion(
511 &http_client,
512 &*token_provider,
513 auth_context,
514 app_version,
515 CompletionBody {
516 thread_id,
517 prompt_id,
518 provider: cloud_llm_client::LanguageModelProvider::Anthropic,
519 model: request.model.clone(),
520 provider_request: serde_json::to_value(&request)
521 .map_err(|e| anyhow!(e))?,
522 },
523 )
524 .await
525 .map_err(|err| match err.downcast::<ApiError>() {
526 Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
527 Err(err) => anyhow!(err),
528 })?;
529
530 let mut mapper = AnthropicEventMapper::new();
531 Ok(map_cloud_completion_events(
532 Box::pin(response_lines(response, includes_status_messages)),
533 &provider_name,
534 move |event| mapper.map_event(event),
535 ))
536 });
537 async move { Ok(future.await?.boxed()) }.boxed()
538 }
539 cloud_llm_client::LanguageModelProvider::OpenAi => {
540 let http_client = self.http_client.clone();
541 let token_provider = self.token_provider.clone();
542 let effort = request
543 .thinking_effort
544 .as_ref()
545 .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
546
547 let mut request = into_open_ai_response(
548 request,
549 &self.model.id.0,
550 self.model.supports_parallel_tool_calls,
551 true,
552 None,
553 None,
554 );
555
556 if enable_thinking && let Some(effort) = effort {
557 request.reasoning = Some(open_ai::responses::ReasoningConfig {
558 effort,
559 summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
560 });
561 }
562
563 let auth_context = token_provider.auth_context(cx);
564 let future = self.request_limiter.stream(async move {
565 let PerformLlmCompletionResponse {
566 response,
567 includes_status_messages,
568 } = Self::perform_llm_completion(
569 &http_client,
570 &*token_provider,
571 auth_context,
572 app_version,
573 CompletionBody {
574 thread_id,
575 prompt_id,
576 provider: cloud_llm_client::LanguageModelProvider::OpenAi,
577 model: request.model.clone(),
578 provider_request: serde_json::to_value(&request)
579 .map_err(|e| anyhow!(e))?,
580 },
581 )
582 .await?;
583
584 let mut mapper = OpenAiResponseEventMapper::new();
585 Ok(map_cloud_completion_events(
586 Box::pin(response_lines(response, includes_status_messages)),
587 &provider_name,
588 move |event| mapper.map_event(event),
589 ))
590 });
591 async move { Ok(future.await?.boxed()) }.boxed()
592 }
593 cloud_llm_client::LanguageModelProvider::XAi => {
594 let http_client = self.http_client.clone();
595 let token_provider = self.token_provider.clone();
596 let request = into_open_ai(
597 request,
598 &self.model.id.0,
599 self.model.supports_parallel_tool_calls,
600 false,
601 None,
602 None,
603 false,
604 );
605 let auth_context = token_provider.auth_context(cx);
606 let future = self.request_limiter.stream(async move {
607 let PerformLlmCompletionResponse {
608 response,
609 includes_status_messages,
610 } = Self::perform_llm_completion(
611 &http_client,
612 &*token_provider,
613 auth_context,
614 app_version,
615 CompletionBody {
616 thread_id,
617 prompt_id,
618 provider: cloud_llm_client::LanguageModelProvider::XAi,
619 model: request.model.clone(),
620 provider_request: serde_json::to_value(&request)
621 .map_err(|e| anyhow!(e))?,
622 },
623 )
624 .await?;
625
626 let mut mapper = OpenAiEventMapper::new();
627 Ok(map_cloud_completion_events(
628 Box::pin(response_lines(response, includes_status_messages)),
629 &provider_name,
630 move |event| mapper.map_event(event),
631 ))
632 });
633 async move { Ok(future.await?.boxed()) }.boxed()
634 }
635 cloud_llm_client::LanguageModelProvider::Google => {
636 let http_client = self.http_client.clone();
637 let token_provider = self.token_provider.clone();
638 let request =
639 into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
640 let auth_context = token_provider.auth_context(cx);
641 let future = self.request_limiter.stream(async move {
642 let PerformLlmCompletionResponse {
643 response,
644 includes_status_messages,
645 } = Self::perform_llm_completion(
646 &http_client,
647 &*token_provider,
648 auth_context,
649 app_version,
650 CompletionBody {
651 thread_id,
652 prompt_id,
653 provider: cloud_llm_client::LanguageModelProvider::Google,
654 model: request.model.model_id.clone(),
655 provider_request: serde_json::to_value(&request)
656 .map_err(|e| anyhow!(e))?,
657 },
658 )
659 .await?;
660
661 let mut mapper = GoogleEventMapper::new();
662 Ok(map_cloud_completion_events(
663 Box::pin(response_lines(response, includes_status_messages)),
664 &provider_name,
665 move |event| mapper.map_event(event),
666 ))
667 });
668 async move { Ok(future.await?.boxed()) }.boxed()
669 }
670 }
671 }
672}
673
/// Maintains the list of models available through Zed's cloud and constructs
/// `CloudLanguageModel` instances for them.
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    /// All models returned by the server, in server order.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// The server-designated default model, if present in `models`.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// The server-designated default fast model, if present in `models`.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-recommended models, in recommendation order.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}
683
684impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
685 pub fn new(
686 token_provider: Arc<TP>,
687 http_client: Arc<HttpClientWithUrl>,
688 app_version: Option<Version>,
689 ) -> Self {
690 Self {
691 token_provider,
692 http_client,
693 app_version,
694 models: Vec::new(),
695 default_model: None,
696 default_fast_model: None,
697 recommended_models: Vec::new(),
698 }
699 }
700
701 pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
702 let http_client = self.http_client.clone();
703 let token_provider = self.token_provider.clone();
704 cx.spawn(async move |this, cx| {
705 let auth_context = token_provider.auth_context(cx);
706 let response =
707 Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
708 this.update(cx, |this, cx| {
709 this.update_models(response);
710 cx.notify();
711 })
712 })
713 }
714
715 async fn fetch_models_request(
716 http_client: &HttpClientWithUrl,
717 token_provider: &TP,
718 auth_context: TP::AuthContext,
719 ) -> Result<ListModelsResponse> {
720 let token = token_provider.acquire_token(auth_context).await?;
721
722 let request = http_client::Request::builder()
723 .method(Method::GET)
724 .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
725 .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
726 .header("Authorization", format!("Bearer {token}"))
727 .body(AsyncBody::empty())?;
728 let mut response = http_client
729 .send(request)
730 .await
731 .context("failed to send list models request")?;
732
733 if response.status().is_success() {
734 let mut body = String::new();
735 response.body_mut().read_to_string(&mut body).await?;
736 Ok(serde_json::from_str(&body)?)
737 } else {
738 let mut body = String::new();
739 response.body_mut().read_to_string(&mut body).await?;
740 anyhow::bail!(
741 "error listing models.\nStatus: {:?}\nBody: {body}",
742 response.status(),
743 );
744 }
745 }
746
747 pub fn update_models(&mut self, response: ListModelsResponse) {
748 let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();
749
750 self.default_model = models
751 .iter()
752 .find(|model| {
753 response
754 .default_model
755 .as_ref()
756 .is_some_and(|default_model_id| &model.id == default_model_id)
757 })
758 .cloned();
759 self.default_fast_model = models
760 .iter()
761 .find(|model| {
762 response
763 .default_fast_model
764 .as_ref()
765 .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
766 })
767 .cloned();
768 self.recommended_models = response
769 .recommended_models
770 .iter()
771 .filter_map(|id| models.iter().find(|model| &model.id == id))
772 .cloned()
773 .collect();
774 self.models = models;
775 }
776
777 pub fn create_model(
778 &self,
779 model: &Arc<cloud_llm_client::LanguageModel>,
780 ) -> Arc<dyn LanguageModel> {
781 Arc::new(CloudLanguageModel::<TP> {
782 id: LanguageModelId::from(model.id.0.to_string()),
783 model: model.clone(),
784 token_provider: self.token_provider.clone(),
785 http_client: self.http_client.clone(),
786 app_version: self.app_version.clone(),
787 request_limiter: RateLimiter::new(4),
788 })
789 }
790
791 pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
792 &self.models
793 }
794
795 pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
796 self.default_model.as_ref()
797 }
798
799 pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
800 self.default_fast_model.as_ref()
801 }
802
803 pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
804 &self.recommended_models
805 }
806}
807
/// Adapts a raw stream of cloud `CompletionEvent`s into the
/// `LanguageModelCompletionEvent` stream expected by callers.
///
/// `map_callback` turns each provider-specific payload event into zero or
/// more completion events. Status events are translated via
/// `LanguageModelCompletionEvent::from_completion_request_status`, except the
/// `StreamEnded` marker, which is consumed here: if the underlying stream
/// finishes without ever seeing it, a `StreamEndedUnexpectedly` error is
/// emitted so callers can distinguish a clean end from a dropped connection.
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Set once the server's explicit end-of-stream status marker arrives.
    let mut saw_stream_ended = false;

    // `done` latches after the inner stream finishes; `pending` buffers
    // mapped events so each poll yields at most one item.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain buffered events before polling the inner stream again.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // Swallow the marker; it only records a clean end.
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // Stream closed without the explicit end marker: surface
                    // this as an unexpected end of stream.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
877
878pub fn provider_name(
879 provider: &cloud_llm_client::LanguageModelProvider,
880) -> LanguageModelProviderName {
881 match provider {
882 cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
883 cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
884 cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
885 cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
886 }
887}
888
/// Converts a newline-delimited JSON response body into a stream of
/// `CompletionEvent`s.
///
/// When `includes_status_messages` is true, each line parses as a full
/// `CompletionEvent` (payload or status); otherwise each line is a bare
/// payload of type `T` wrapped in `CompletionEvent::Event`.
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        // The line buffer is threaded through the state so it is reused
        // across iterations instead of reallocated per line.
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // EOF: terminate the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
913
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Builds an `ApiError` with the given body and a 500 status, mirroring
    /// what `perform_llm_completion` produces on a non-success response.
    fn api_error(body: &str) -> ApiError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should become
        // UpstreamProviderError attributed to the upstream provider.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let completion_error: LanguageModelCompletionError = api_error(error_body).into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also become
        // UpstreamProviderError (not attributed to Zed's own API).
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let completion_error: LanguageModelCompletionError = api_error(error_body).into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should become
        // UpstreamProviderError as well.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let completion_error: LanguageModelCompletionError = api_error(error_body).into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let completion_error: LanguageModelCompletionError =
            api_error("Regular internal server error").into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError,
        // with the status parsed out of the code string and retry_after honored.
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let completion_error: LanguageModelCompletionError = api_error(error_body).into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let completion_error: LanguageModelCompletionError =
            api_error("Not JSON at all").into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}