1use anthropic::AnthropicModelMode;
2use anyhow::{Context as _, Result, anyhow};
3use cloud_llm_client::{
4 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
5 CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
6 CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
7 OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
8 ZED_VERSION_HEADER_NAME,
9};
10use futures::{
11 AsyncBufReadExt, FutureExt, Stream, StreamExt,
12 future::BoxFuture,
13 stream::{self, BoxStream},
14};
15use google_ai::GoogleModelMode;
16use gpui::{App, AppContext, AsyncApp, Context, Task};
17use http_client::http::{HeaderMap, HeaderValue};
18use http_client::{
19 AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
20};
21use language_model::{
22 ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
23 LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
24 LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
25 LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
26 LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
27 OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
28 ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
29};
30
31use schemars::JsonSchema;
32use semver::Version;
33use serde::{Deserialize, Serialize, de::DeserializeOwned};
34use smol::io::{AsyncReadExt, BufReader};
35use std::collections::VecDeque;
36use std::pin::Pin;
37use std::str::FromStr;
38use std::sync::Arc;
39use std::task::Poll;
40use std::time::Duration;
41use thiserror::Error;
42
43use anthropic::completion::{
44 AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
45};
46use google_ai::completion::{GoogleEventMapper, into_google};
47use open_ai::completion::{
48 OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
49 into_open_ai_response,
50};
51use x_ai::completion::count_xai_tokens;
52
/// Identity under which all cloud-hosted models are registered with the
/// language-model registry (as opposed to the upstream provider identities).
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
55
/// Trait for acquiring and refreshing LLM authentication tokens.
pub trait CloudLlmTokenProvider: Send + Sync {
    /// Implementation-defined context captured up front and threaded through
    /// token acquisition/refresh calls.
    type AuthContext: Clone + Send + 'static;

    /// Captures whatever context is needed to later acquire or refresh a token.
    fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext;
    /// Obtains a bearer token for LLM requests.
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    /// Obtains a fresh token; called when the server reports the current token
    /// as expired or outdated (see `needs_llm_token_refresh`).
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
64
/// Reasoning mode configured for a model.
///
/// Serialized with a lowercase `type` tag, e.g.
/// `{"type": "thinking", "budget_tokens": 4096}`.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion without extended thinking.
    #[default]
    Default,
    /// Extended-thinking mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
75
76impl From<ModelMode> for AnthropicModelMode {
77 fn from(value: ModelMode) -> Self {
78 match value {
79 ModelMode::Default => AnthropicModelMode::Default,
80 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
81 }
82 }
83}
84
/// A single cloud-hosted model exposed through Zed's LLM service.
pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    /// Server-provided metadata (upstream provider, capabilities, limits).
    pub model: Arc<cloud_llm_client::LanguageModel>,
    /// Supplies and refreshes the bearer token used on every request.
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    /// Sent as a request header when known, so the server can gate behavior by
    /// client version.
    pub app_version: Option<Version>,
    /// Limits concurrent in-flight completion requests for this model.
    pub request_limiter: RateLimiter,
}
93
/// A successful `/completions` response plus the negotiated stream format.
pub struct PerformLlmCompletionResponse {
    /// The raw streaming HTTP response.
    pub response: Response<AsyncBody>,
    /// Whether the server indicated (via response header) that the stream
    /// interleaves status messages with completion events.
    pub includes_status_messages: bool,
}
98
impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
    /// Sends a completion request to the cloud `/completions` endpoint.
    ///
    /// Retries at most once with a freshly refreshed token when the server
    /// reports the current LLM token as expired or outdated. Other failures
    /// surface as errors: `PaymentRequiredError` for HTTP 402, `ApiError`
    /// (status + body + headers) for everything else.
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        // Guards against refresh loops: the token is refreshed at most once.
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise client capabilities so the server may interleave
                // status messages and emit an explicit stream-ended marker.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                // The server echoes its own support via a response header; the
                // stream parser needs this to know how to decode lines.
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Capture headers before consuming the body so both are available
            // for downstream error classification.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
158
159fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
160 response
161 .headers()
162 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
163 .is_some()
164 || response
165 .headers()
166 .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
167 .is_some()
168}
169
/// A non-success HTTP response from the cloud LLM service, preserving status,
/// raw body, and headers for downstream error classification.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
177
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable error code, e.g. `upstream_http_error` or
    /// `upstream_http_429`.
    code: String,
    /// Human-readable description of the failure.
    message: String,
    /// HTTP status returned by the upstream provider, when the server relays
    /// one. Invalid numeric values deserialize to `None` rather than failing.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Suggested retry delay in (possibly fractional) seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
198
199fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
200where
201 D: serde::Deserializer<'de>,
202{
203 let opt: Option<u16> = Option::deserialize(deserializer)?;
204 Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
205}
206
impl From<ApiError> for LanguageModelCompletionError {
    /// Classifies a raw HTTP error from the cloud service.
    ///
    /// If the body parses as a [`CloudApiError`] whose code starts with
    /// `upstream_http_`, the error is attributed to the upstream provider with
    /// the best available status code; otherwise it is treated as an error
    /// from Zed's own service at the status we actually received.
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    // Generic "upstream_http_error" with no explicit upstream
                    // status: fall back to the status of our own response.
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429")
                    // then use that; otherwise, see if the JSON contains a status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            // Structured error from Zed's service itself: use its message but
            // keep the HTTP status we received.
            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        // Unstructured body: surface it verbatim alongside the status.
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}
250
impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    // All cloud models report Zed's cloud provider identity; the upstream
    // provider is exposed separately below.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// The provider actually hosting this model behind Zed's cloud proxy.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    // Capability flags are delegated verbatim to server-provided model metadata.
    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    // Every tool-choice mode is accepted for all cloud models.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    // Google and xAI only accept a subset of JSON Schema in tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    /// Prompt caching is only enabled for Anthropic models; the constants here
    /// mirror Anthropic's caching parameters (minimum cacheable size, anchor
    /// limit).
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    /// Counts the tokens a request would consume.
    ///
    /// Anthropic/OpenAI/xAI are counted locally on a background thread; Google
    /// has no local tokenizer here, so counting is delegated to the cloud
    /// `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_open_ai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_xai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(&cx.to_async());
                async move {
                    let token = token_provider.acquire_token(auth_context).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        // NOTE: no token-refresh retry here, unlike
                        // perform_llm_completion. Errors carry the full
                        // status/body/headers for classification.
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion through the cloud `/completions` endpoint.
    ///
    /// Each provider arm converts the generic request into the provider's wire
    /// format, posts it through `perform_llm_completion` (rate-limited), and
    /// maps the resulting newline-delimited event stream into
    /// [`LanguageModelCompletionEvent`]s via the provider's event mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        // Thinking is only enabled when both the request allows it and the
        // model supports it.
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                // Unparseable effort strings are silently ignored.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                // NOTE(review): 1.0 here is presumably the sampling
                // temperature, and 4_096 a default thinking budget — confirm
                // against `into_anthropic`'s signature.
                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                // When an explicit effort level is given, switch to adaptive
                // thinking with that effort instead of the fixed budget.
                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // Recover the structured ApiError (if any) so it maps into
                    // a precise LanguageModelCompletionError variant.
                    // NOTE(review): only this Anthropic arm performs this
                    // downcast; the OpenAI/xAI/Google arms below use a plain
                    // `?` — confirm whether that asymmetry is intentional.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                // OpenAI uses the Responses API shape (into_open_ai_response),
                // unlike xAI below which uses the Chat Completions shape.
                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
668
/// Catalog of cloud-hosted models, refreshed from the server's `/models`
/// endpoint, plus the defaults/recommendations the server designates.
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    /// All models the server listed, in listing order.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated default model, if it appears in `models`.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-designated default "fast" model, if it appears in `models`.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Server-recommended models resolved against `models`.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}
678
impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    /// Creates an empty provider; call [`Self::refresh_models`] to populate it.
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

    /// Fetches the model list from the server and updates this provider's
    /// catalog, notifying observers on success.
    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

    /// Performs the authenticated GET against the `/models` endpoint.
    ///
    /// Unlike `perform_llm_completion`, this does not retry on an
    /// expired/outdated token; failures are returned with status and body.
    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            // Advertise xAI support so the server includes xAI models.
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

    /// Replaces the catalog with the server's listing, resolving the default,
    /// default-fast, and recommended model ids against the listed models.
    /// Ids that don't match any listed model are dropped.
    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

    /// Wraps catalog metadata into a usable [`CloudLanguageModel`] sharing
    /// this provider's HTTP client and token provider.
    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            // Allow up to 4 concurrent requests per model.
            request_limiter: RateLimiter::new(4),
        })
    }

    /// All models from the most recent successful refresh.
    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    /// The server-designated default model, if any.
    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    /// The server-designated default fast model, if any.
    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    /// The server-recommended models, in the server's order.
    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}
802
/// Adapts a stream of cloud completion events into the generic
/// [`LanguageModelCompletionEvent`] stream consumed by callers.
///
/// Status events are translated directly; provider-specific events go through
/// `map_callback` (which may yield zero or more output events per input).
/// If the underlying stream ends without the server's explicit `StreamEnded`
/// status, a final `StreamEndedUnexpectedly` error is emitted, so truncated
/// responses are never silently mistaken for complete ones.
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Set once the server's explicit end-of-stream status arrives.
    let mut saw_stream_ended = false;

    // State machine: `pending` buffers output events (one input event can fan
    // out to several outputs); `done` latches after the input is exhausted.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain buffered outputs before polling for more input.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            // Marker only — record it, emit nothing.
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // Input exhausted without the explicit end marker: surface
                    // a truncation error as the stream's final item.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
872
873pub fn provider_name(
874 provider: &cloud_llm_client::LanguageModelProvider,
875) -> LanguageModelProviderName {
876 match provider {
877 cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
878 cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
879 cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
880 cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
881 }
882}
883
/// Parses a newline-delimited JSON response body into completion events.
///
/// When `includes_status_messages` is true (the server advertised support for
/// them), each line is decoded as a full [`CompletionEvent`], which may be a
/// status message; otherwise each line is decoded as the provider event type
/// `T` and wrapped in [`CompletionEvent::Event`]. The line buffer is reused
/// across iterations; the stream ends cleanly on EOF (a zero-byte read).
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
908
#[cfg(test)]
mod tests {
    //! Covers the `ApiError -> LanguageModelCompletionError` classification:
    //! upstream errors (with explicit status, suffix-encoded status, and
    //! retry-after), plus fallbacks for plain and non-JSON bodies.
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with 503 status should become ServerOverloaded
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with 500 status should become ApiInternalServerError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with 429 status should become RateLimitExceeded
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        // (status parsed from the code suffix, retry_after from the JSON).
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}