language_models_cloud.rs

use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, OUTDATED_LLM_TOKEN_HEADER_NAME,
    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{AppContext, AsyncApp, Context, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{
    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
};
use language_model::{
    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};
use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;

use anthropic::completion::{AnthropicEventMapper, into_anthropic};
use google_ai::completion::{GoogleEventMapper, into_google};
use open_ai::completion::{
    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
};

const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;

/// Trait for acquiring and refreshing LLM authentication tokens.
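///
/// A minimal sketch of an implementation backed by a fixed token (e.g. for
/// tests); `StaticTokenProvider` is hypothetical and not part of this module:
///
/// ```ignore
/// struct StaticTokenProvider {
///     token: String,
/// }
///
/// impl CloudLlmTokenProvider for StaticTokenProvider {
///     // No per-request auth state is needed for a fixed token.
///     type AuthContext = ();
///
///     fn auth_context(&self, _cx: &impl AppContext) -> Self::AuthContext {}
///
///     fn acquire_token(&self, _auth: ()) -> BoxFuture<'static, Result<String>> {
///         let token = self.token.clone();
///         async move { Ok(token) }.boxed()
///     }
///
///     fn refresh_token(&self, _auth: ()) -> BoxFuture<'static, Result<String>> {
///         // A fixed token cannot be refreshed; fail rather than retry forever.
///         async move { Err(anyhow!("static tokens cannot be refreshed")) }.boxed()
///     }
/// }
/// ```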
pub trait CloudLlmTokenProvider: Send + Sync {
    type AuthContext: Clone + Send + 'static;

    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}

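/// The reasoning mode to request for a completion.
///
/// With `#[serde(tag = "type", rename_all = "lowercase")]`, the thinking
/// variant serializes as, for example:
///
/// ```json
/// { "type": "thinking", "budget_tokens": 4096 }
/// ```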
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    pub model: Arc<cloud_llm_client::LanguageModel>,
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    pub app_version: Option<Version>,
    pub request_limiter: RateLimiter,
}

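/// A successful response from the completions endpoint, along with whether the
/// server signaled (via response header) that the stream includes status
/// messages.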
pub struct PerformLlmCompletionResponse {
    pub response: Response<AsyncBody>,
    pub includes_status_messages: bool,
}

impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
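    /// Sends a completion request to the cloud completions endpoint,
    /// refreshing the LLM token and retrying once if the server reports it as
    /// expired or outdated.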
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

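/// Whether the response headers indicate that the LLM token has expired or is
/// outdated and should be refreshed before retrying.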
fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
    response
        .headers()
        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
        .is_some()
        || response
            .headers()
            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
            .is_some()
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

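/// Deserializes an optional numeric status code, treating values that are not
/// valid HTTP status codes as absent rather than failing deserialization.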
fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If the code string itself carries a status code (e.g.
                    // "upstream_http_429"), use that; otherwise fall back to
                    // the status of the response we received.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                if !self.model.supports_fast_mode {
                    request.speed = None;
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // Surface structured API errors the same way the Anthropic
                    // branch does, instead of collapsing them into a generic error.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                    false,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // As above, convert ApiError into a structured completion error.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // As above, convert ApiError into a structured completion error.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

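/// Maintains the set of cloud-hosted models fetched from the server and
/// constructs [`CloudLanguageModel`] instances that share its token provider
/// and HTTP client.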
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}

impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

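    /// Fetches the current model list from the server, replaces the cached
    /// models, and notifies observers of this entity.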
    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        if response.status().is_success() {
            Ok(serde_json::from_str(&body)?)
        } else {
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

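    /// Replaces the cached model list and re-resolves the default, default
    /// fast, and recommended models against it by id.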
    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

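    /// Wraps a cloud model description in a [`CloudLanguageModel`] that shares
    /// this provider's token provider and HTTP client.
    ///
    /// A usage sketch, assuming the model list has already been loaded via
    /// `refresh_models`:
    ///
    /// ```ignore
    /// if let Some(model) = provider.default_model() {
    ///     let model: Arc<dyn LanguageModel> = provider.create_model(model);
    /// }
    /// ```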
    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}

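/// Adapts a stream of raw cloud [`CompletionEvent`]s into
/// [`LanguageModelCompletionEvent`]s.
///
/// Status events are handled here: the `StreamEnded` sentinel is consumed, and
/// if the stream finishes without one, a `StreamEndedUnexpectedly` error is
/// emitted. Provider-specific events are flattened through `map_callback`,
/// which may produce zero or more output events per input event.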
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

pub fn provider_name(
    provider: &cloud_llm_client::LanguageModelProvider,
) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
    }
}

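/// Parses the response body as newline-delimited JSON, yielding one
/// [`CompletionEvent`] per line. When the server did not signal support for
/// status messages, each line is parsed as a bare provider event instead.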
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // An upstream_http_error with a 503 upstream status should become an
        // UpstreamProviderError carrying the upstream message.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error with a 500 upstream status should likewise
        // become an UpstreamProviderError.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // An upstream_http_error with a 429 upstream status should become an
        // UpstreamProviderError as well.
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain 500 without an upstream_http_error body should map to
        // ApiInternalServerError attributed to Zed.
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The upstream_http_429 code format should also convert to an
        // UpstreamProviderError, with the status parsed from the code string.
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // An error body that isn't JSON should fall back to plain HTTP status
        // handling.
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}