language_models_cloud.rs

use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, OUTDATED_LLM_TOKEN_HEADER_NAME,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, AsyncReadExt as _, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    io::BufReader,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AppContext, AsyncApp, Context, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{
    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
};
use language_model::{
    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;

use anthropic::completion::{AnthropicEventMapper, into_anthropic};
use google_ai::completion::{GoogleEventMapper, into_google};
use open_ai::completion::{
    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
};

const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;

/// Trait for acquiring and refreshing LLM authentication tokens.
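///
/// A minimal implementation sketch (illustrative only: `StaticTokenProvider`
/// is not part of this crate, and a real provider would exchange credentials
/// with the cloud auth endpoint rather than hand back a fixed token):
///
/// ```ignore
/// struct StaticTokenProvider {
///     token: String,
/// }
///
/// impl CloudLlmTokenProvider for StaticTokenProvider {
///     type AuthContext = ();
///
///     fn auth_context(&self, _cx: &impl AppContext) -> Self::AuthContext {}
///
///     fn acquire_token(&self, _auth: ()) -> BoxFuture<'static, Result<String>> {
///         let token = self.token.clone();
///         async move { Ok(token) }.boxed()
///     }
///
///     fn refresh_token(&self, auth: ()) -> BoxFuture<'static, Result<String>> {
///         // A fixed token cannot be refreshed; return the same value.
///         self.acquire_token(auth)
///     }
/// }
/// ```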
pub trait CloudLlmTokenProvider: Send + Sync {
    type AuthContext: Clone + Send + 'static;

    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}

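/// The reasoning mode to request from a model.
///
/// Serialized with an internally tagged representation; for example,
/// `{"type":"default"}` or `{"type":"thinking","budget_tokens":4096}`.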
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    pub model: Arc<cloud_llm_client::LanguageModel>,
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    pub app_version: Option<Version>,
    pub request_limiter: RateLimiter,
}

pub struct PerformLlmCompletionResponse {
    pub response: Response<AsyncBody>,
    pub includes_status_messages: bool,
}

impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
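    /// Sends a completion request to the cloud, refreshing the LLM token and
    /// retrying once if the server reports the token as expired or outdated.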
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
    response
        .headers()
        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
        .is_some()
        || response
            .headers()
            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
            .is_some()
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
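///
/// The status may instead be embedded in the code itself, with an optional
/// retry hint (this shape appears in the tests below):
/// ```json
/// {
///   "code": "upstream_http_429",
///   "message": "Upstream Anthropic rate limit exceeded.",
///   "retry_after": 30.5
/// }
/// ```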
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

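/// Deserializes an optional numeric status code, discarding values that are
/// not valid HTTP status codes.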
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

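// For upstream errors, the status is resolved in order of preference: the
// `upstream_status` field when present; the outer response status for plain
// "upstream_http_error" codes; and otherwise a status parsed out of the code
// string itself (e.g. "upstream_http_429").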
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // The code embeds a status (e.g. "upstream_http_429");
                    // parse it, falling back to the outer response status.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
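        // Each branch converts the unified request into the provider's native
        // wire format, streams it through the cloud completions endpoint, and
        // maps the provider-specific events back into unified events.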
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive {
                        display: Some(anthropic::AdaptiveThinkingDisplay::Summarized),
                    });
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                if !self.model.supports_fast_mode {
                    request.speed = None;
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                    false,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

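/// Holds the model catalog fetched from the cloud and constructs
/// [`CloudLanguageModel`]s from it.
///
/// A usage sketch (illustrative; assumes a `cx: &mut Context<_>` from the
/// owning entity):
///
/// ```ignore
/// provider.refresh_models(cx).detach();
/// // ...later, once the refresh has completed:
/// if let Some(model) = provider.default_model() {
///     let model: Arc<dyn LanguageModel> = provider.create_model(model);
/// }
/// ```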
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}

impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

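    /// Fetches the current model list from the cloud, replaces the cached
    /// catalog, and notifies observers on success.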
    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        if response.status().is_success() {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            Ok(serde_json::from_str(&body)?)
        } else {
            let mut body = String::new();
            response.body_mut().read_to_string(&mut body).await?;
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

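    /// Replaces the cached model list and re-resolves the default, fast, and
    /// recommended models against it by id.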
    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}

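/// Adapts a raw cloud completion stream into unified completion events.
///
/// Status events are handled here; provider-specific events are forwarded to
/// `map_callback`. If the stream finishes without a `StreamEnded` status, a
/// final `StreamEndedUnexpectedly` error is emitted so callers can tell a
/// truncated response from a clean completion.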
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

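/// Maps a cloud model provider to its upstream provider's display name.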
pub fn provider_name(
    provider: &cloud_llm_client::LanguageModelProvider,
) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
    }
}

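/// Parses a streaming response body as newline-delimited JSON.
///
/// When the server advertised status-message support, each line is a full
/// `CompletionEvent`; otherwise each line is a bare provider event and gets
/// wrapped in `CompletionEvent::Event`.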
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with upstream_status 503 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with upstream_status 500 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with upstream_status 429 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain 500 error without an upstream_http_error code should map to
        // ApiInternalServerError attributed to Zed itself
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}
982}