use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
    OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{App, AppContext, AsyncApp, Context, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{
    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
};
use language_model::{
    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};

use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;

use anthropic::completion::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use google_ai::completion::{GoogleEventMapper, into_google};
use open_ai::completion::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use x_ai::completion::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;

/// Trait for acquiring and refreshing LLM authentication tokens.
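///
/// A minimal sketch of an implementation, assuming a hypothetical provider
/// that hands out a fixed token with no per-request auth state; a real
/// implementation would acquire and refresh tokens through the client's
/// authentication flow.
///
/// ```ignore
/// struct StaticTokenProvider {
///     token: String,
/// }
///
/// impl CloudLlmTokenProvider for StaticTokenProvider {
///     type AuthContext = ();
///
///     fn auth_context(&self, _cx: &impl AppContext) -> Self::AuthContext {}
///
///     fn acquire_token(&self, _auth: ()) -> BoxFuture<'static, Result<String>> {
///         let token = self.token.clone();
///         async move { Ok(token) }.boxed()
///     }
///
///     fn refresh_token(&self, _auth: ()) -> BoxFuture<'static, Result<String>> {
///         // A static token cannot actually be refreshed; a real provider
///         // would re-authenticate here.
///         let token = self.token.clone();
///         async move { Ok(token) }.boxed()
///     }
/// }
/// ```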
pub trait CloudLlmTokenProvider: Send + Sync {
    type AuthContext: Clone + Send + 'static;

    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}

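/// Controls whether a model runs in its default mode or an extended thinking
/// mode.
///
/// Serialized with an internally tagged representation, so the thinking
/// variant round-trips through JSON as, for example:
///
/// ```json
/// { "type": "thinking", "budget_tokens": 4096 }
/// ```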
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

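/// A language model served through Zed's cloud LLM service, which proxies
/// requests to the upstream provider.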
pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    pub model: Arc<cloud_llm_client::LanguageModel>,
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    pub app_version: Option<Version>,
    pub request_limiter: RateLimiter,
}

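/// A successful `/completions` response, along with whether the server will
/// interleave status messages into the event stream.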
pub struct PerformLlmCompletionResponse {
    pub response: Response<AsyncBody>,
    pub includes_status_messages: bool,
}

impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

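        // Send the request, retrying exactly once with a refreshed token if
        // the server reports the current one as expired or outdated.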
        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
    response
        .headers()
        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
        .is_some()
        || response
            .headers()
            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
            .is_some()
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

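/// Deserializes an optional HTTP status code, treating out-of-range values as
/// absent rather than failing deserialization of the whole error.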
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The status may be embedded in the error code itself
                    // (e.g. "upstream_http_429"); parse it from there, falling
                    // back to the response's own status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
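        // Anthropic, OpenAI, and xAI token counts are computed locally on a
        // background thread; Google counts are delegated to the cloud
        // `/count_tokens` endpoint.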
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_open_ai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_xai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                async move {
                    let token = token_provider.acquire_token(auth_context).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
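        // Each arm converts the unified request into the upstream provider's
        // wire format, streams the completion through the cloud endpoint, and
        // maps the raw events back with a provider-specific event mapper.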
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                if !self.model.supports_fast_mode {
                    request.speed = None;
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}

impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

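    /// Replaces the cached model list, resolving the default, default fast,
    /// and recommended models against the IDs in the response.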
    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}

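/// Adapts a stream of raw cloud completion events into
/// `LanguageModelCompletionEvent`s.
///
/// Status events are translated directly, provider events are expanded via
/// `map_callback`, and mapped items are buffered so each poll yields one item.
/// If the stream ends without a `StreamEnded` status, a final
/// `StreamEndedUnexpectedly` error is emitted.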
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

pub fn provider_name(
    provider: &cloud_llm_client::LanguageModelProvider,
) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
    }
}

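/// Reads the response body as newline-delimited JSON, yielding one
/// `CompletionEvent` per line. When the server does not stream status
/// messages, each line is parsed as a bare provider event and wrapped in
/// `CompletionEvent::Event`.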
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with a 503 upstream status should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 500 upstream status should also become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with a 429 upstream status should become UpstreamProviderError as well
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}