language_models_cloud.rs

   1use anthropic::AnthropicModelMode;
   2use anyhow::{Context as _, Result, anyhow};
   3use cloud_llm_client::{
   4    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   5    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
   6    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
   7    OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
   8    ZED_VERSION_HEADER_NAME,
   9};
  10use futures::{
  11    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  12    future::BoxFuture,
  13    stream::{self, BoxStream},
  14};
  15use google_ai::GoogleModelMode;
  16use gpui::{App, AppContext, AsyncApp, Context, Task};
  17use http_client::http::{HeaderMap, HeaderValue};
  18use http_client::{
  19    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
  20};
  21use language_model::{
  22    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
  23    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
  24    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
  25    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
  26    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
  27    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
  28    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
  29};
  30
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use smol::io::{AsyncReadExt, BufReader};
  35use std::collections::VecDeque;
  36use std::pin::Pin;
  37use std::str::FromStr;
  38use std::sync::Arc;
  39use std::task::Poll;
  40use std::time::Duration;
  41use thiserror::Error;
  42
  43use anthropic::completion::{
  44    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  45};
  46use google_ai::completion::{GoogleEventMapper, into_google};
  47use open_ai::completion::{
  48    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  49    into_open_ai_response,
  50};
  51use x_ai::completion::count_xai_tokens;
  52
/// Provider id reported for every model served through Zed's cloud.
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Human-readable provider name paired with [`PROVIDER_ID`].
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
  55
/// Trait for acquiring and refreshing LLM authentication tokens.
pub trait CloudLlmTokenProvider: Send + Sync {
    /// Opaque, clonable state captured on the app thread (via [`Self::auth_context`])
    /// and handed back to the token futures, which may run off-thread.
    type AuthContext: Clone + Send + 'static;

    /// Captures whatever app state is needed to authenticate later.
    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    /// Obtains a bearer token for cloud LLM requests.
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    /// Obtains a fresh token; called after the server marks the current one
    /// as expired or outdated (see `needs_llm_token_refresh`).
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
  64
/// Serialized (settings-facing) reasoning mode for a model.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Plain completion, no extended thinking.
    #[default]
    Default,
    /// Extended-thinking mode (mirrors `AnthropicModelMode::Thinking`).
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  75
  76impl From<ModelMode> for AnthropicModelMode {
  77    fn from(value: ModelMode) -> Self {
  78        match value {
  79            ModelMode::Default => AnthropicModelMode::Default,
  80            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  81        }
  82    }
  83}
  84
/// A single model exposed through Zed's cloud, implementing [`LanguageModel`].
pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    /// Server-provided model description (capabilities, limits, provider).
    pub model: Arc<cloud_llm_client::LanguageModel>,
    /// Source of bearer tokens for authenticated requests.
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    /// Client version forwarded to the server when known.
    pub app_version: Option<Version>,
    /// Limits the number of concurrent in-flight completion requests.
    pub request_limiter: RateLimiter,
}
  93
/// Successful result of `perform_llm_completion`.
pub struct PerformLlmCompletionResponse {
    /// The raw streaming HTTP response from the `/completions` endpoint.
    pub response: Response<AsyncBody>,
    /// Whether the server indicated (via response header) that the stream
    /// interleaves status messages with completion events.
    pub includes_status_messages: bool,
}
  98
impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
    /// Sends `body` to the cloud `/completions` endpoint, authenticating with
    /// a token from `token_provider`.
    ///
    /// If the server reports the token as expired or outdated, a fresh token
    /// is fetched and the request retried exactly once. Non-success responses
    /// become [`PaymentRequiredError`] (HTTP 402) or [`ApiError`] (anything
    /// else, with the response body and headers attached).
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        // Ensures we refresh at most once per call, preventing a retry loop.
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                // Version header is optional; omitted when unknown.
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise client capabilities so the server may include
                // status messages / a stream-ended marker in the stream.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Single retry with a fresh token when the server says ours is stale.
            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Capture body and headers so callers can downcast to `ApiError`
            // and recover structured error details.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 158
 159fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
 160    response
 161        .headers()
 162        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 163        .is_some()
 164        || response
 165            .headers()
 166            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
 167            .is_some()
 168}
 169
/// Error for a non-success HTTP response from the cloud API, carrying the
/// status, the full response body text, and the response headers.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 177
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    // Machine-readable code, e.g. "upstream_http_error" or "upstream_http_429".
    code: String,
    // Human-readable description of the failure.
    message: String,
    // Status code reported by the upstream provider, when present. Values
    // that are not valid HTTP status codes deserialize to `None` instead of
    // failing the whole parse.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    // Suggested retry delay in (possibly fractional) seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
 198
 199fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 200where
 201    D: serde::Deserializer<'de>,
 202{
 203    let opt: Option<u16> = Option::deserialize(deserializer)?;
 204    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 205}
 206
impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        // Prefer the structured JSON error body when the server sent one.
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                // Resolve the upstream provider's status, in order of
                // preference: the explicit `upstream_status` field; our own
                // response status for generic "*_error" codes; or a status
                // embedded in the code string itself.
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If there's a status code in the code string (e.g. "upstream_http_429")
                    // then use that; otherwise, see if the JSON contains a status code.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            // Structured error that did not originate upstream: keep the
            // server's message but classify by our own HTTP status.
            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        // Unstructured body: fall back to a plain HTTP-status error with the
        // raw body as the message.
        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}
 250
 251impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
 252    fn id(&self) -> LanguageModelId {
 253        self.id.clone()
 254    }
 255
 256    fn name(&self) -> LanguageModelName {
 257        LanguageModelName::from(self.model.display_name.clone())
 258    }
 259
 260    fn provider_id(&self) -> LanguageModelProviderId {
 261        PROVIDER_ID
 262    }
 263
 264    fn provider_name(&self) -> LanguageModelProviderName {
 265        PROVIDER_NAME
 266    }
 267
 268    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 269        use cloud_llm_client::LanguageModelProvider::*;
 270        match self.model.provider {
 271            Anthropic => ANTHROPIC_PROVIDER_ID,
 272            OpenAi => OPEN_AI_PROVIDER_ID,
 273            Google => GOOGLE_PROVIDER_ID,
 274            XAi => X_AI_PROVIDER_ID,
 275        }
 276    }
 277
 278    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 279        use cloud_llm_client::LanguageModelProvider::*;
 280        match self.model.provider {
 281            Anthropic => ANTHROPIC_PROVIDER_NAME,
 282            OpenAi => OPEN_AI_PROVIDER_NAME,
 283            Google => GOOGLE_PROVIDER_NAME,
 284            XAi => X_AI_PROVIDER_NAME,
 285        }
 286    }
 287
 288    fn is_latest(&self) -> bool {
 289        self.model.is_latest
 290    }
 291
 292    fn supports_tools(&self) -> bool {
 293        self.model.supports_tools
 294    }
 295
 296    fn supports_images(&self) -> bool {
 297        self.model.supports_images
 298    }
 299
 300    fn supports_thinking(&self) -> bool {
 301        self.model.supports_thinking
 302    }
 303
 304    fn supports_fast_mode(&self) -> bool {
 305        self.model.supports_fast_mode
 306    }
 307
 308    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 309        self.model
 310            .supported_effort_levels
 311            .iter()
 312            .map(|effort_level| LanguageModelEffortLevel {
 313                name: effort_level.name.clone().into(),
 314                value: effort_level.value.clone().into(),
 315                is_default: effort_level.is_default.unwrap_or(false),
 316            })
 317            .collect()
 318    }
 319
 320    fn supports_streaming_tools(&self) -> bool {
 321        self.model.supports_streaming_tools
 322    }
 323
 324    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 325        match choice {
 326            LanguageModelToolChoice::Auto
 327            | LanguageModelToolChoice::Any
 328            | LanguageModelToolChoice::None => true,
 329        }
 330    }
 331
 332    fn supports_split_token_display(&self) -> bool {
 333        use cloud_llm_client::LanguageModelProvider::*;
 334        matches!(self.model.provider, OpenAi | XAi)
 335    }
 336
 337    fn telemetry_id(&self) -> String {
 338        format!("zed.dev/{}", self.model.id)
 339    }
 340
 341    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 342        match self.model.provider {
 343            cloud_llm_client::LanguageModelProvider::Anthropic
 344            | cloud_llm_client::LanguageModelProvider::OpenAi => {
 345                LanguageModelToolSchemaFormat::JsonSchema
 346            }
 347            cloud_llm_client::LanguageModelProvider::Google
 348            | cloud_llm_client::LanguageModelProvider::XAi => {
 349                LanguageModelToolSchemaFormat::JsonSchemaSubset
 350            }
 351        }
 352    }
 353
 354    fn max_token_count(&self) -> u64 {
 355        self.model.max_token_count as u64
 356    }
 357
 358    fn max_output_tokens(&self) -> Option<u64> {
 359        Some(self.model.max_output_tokens as u64)
 360    }
 361
 362    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 363        match &self.model.provider {
 364            cloud_llm_client::LanguageModelProvider::Anthropic => {
 365                Some(LanguageModelCacheConfiguration {
 366                    min_total_token: 2_048,
 367                    should_speculate: true,
 368                    max_cache_anchors: 4,
 369                })
 370            }
 371            cloud_llm_client::LanguageModelProvider::OpenAi
 372            | cloud_llm_client::LanguageModelProvider::XAi
 373            | cloud_llm_client::LanguageModelProvider::Google => None,
 374        }
 375    }
 376
 377    fn count_tokens(
 378        &self,
 379        request: LanguageModelRequest,
 380        cx: &App,
 381    ) -> BoxFuture<'static, Result<u64>> {
 382        match self.model.provider {
 383            cloud_llm_client::LanguageModelProvider::Anthropic => cx
 384                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
 385                .boxed(),
 386            cloud_llm_client::LanguageModelProvider::OpenAi => {
 387                let model = match open_ai::Model::from_id(&self.model.id.0) {
 388                    Ok(model) => model,
 389                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 390                };
 391                cx.background_spawn(async move { count_open_ai_tokens(request, model) })
 392                    .boxed()
 393            }
 394            cloud_llm_client::LanguageModelProvider::XAi => {
 395                let model = match x_ai::Model::from_id(&self.model.id.0) {
 396                    Ok(model) => model,
 397                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
 398                };
 399                cx.background_spawn(async move { count_xai_tokens(request, model) })
 400                    .boxed()
 401            }
 402            cloud_llm_client::LanguageModelProvider::Google => {
 403                let http_client = self.http_client.clone();
 404                let token_provider = self.token_provider.clone();
 405                let model_id = self.model.id.to_string();
 406                let generate_content_request =
 407                    into_google(request, model_id.clone(), GoogleModelMode::Default);
 408                let auth_context = token_provider.auth_context(cx);
 409                async move {
 410                    let token = token_provider.acquire_token(auth_context).await?;
 411
 412                    let request_body = CountTokensBody {
 413                        provider: cloud_llm_client::LanguageModelProvider::Google,
 414                        model: model_id,
 415                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
 416                            generate_content_request,
 417                        })?,
 418                    };
 419                    let request = http_client::Request::builder()
 420                        .method(Method::POST)
 421                        .uri(
 422                            http_client
 423                                .build_zed_llm_url("/count_tokens", &[])?
 424                                .as_ref(),
 425                        )
 426                        .header("Content-Type", "application/json")
 427                        .header("Authorization", format!("Bearer {token}"))
 428                        .body(serde_json::to_string(&request_body)?.into())?;
 429                    let mut response = http_client.send(request).await?;
 430                    let status = response.status();
 431                    let headers = response.headers().clone();
 432                    let mut response_body = String::new();
 433                    response
 434                        .body_mut()
 435                        .read_to_string(&mut response_body)
 436                        .await?;
 437
 438                    if status.is_success() {
 439                        let response_body: CountTokensResponse =
 440                            serde_json::from_str(&response_body)?;
 441
 442                        Ok(response_body.tokens as u64)
 443                    } else {
 444                        Err(anyhow!(ApiError {
 445                            status,
 446                            body: response_body,
 447                            headers
 448                        }))
 449                    }
 450                }
 451                .boxed()
 452            }
 453        }
 454    }
 455
 456    fn stream_completion(
 457        &self,
 458        request: LanguageModelRequest,
 459        cx: &AsyncApp,
 460    ) -> BoxFuture<
 461        'static,
 462        Result<
 463            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 464            LanguageModelCompletionError,
 465        >,
 466    > {
 467        let thread_id = request.thread_id.clone();
 468        let prompt_id = request.prompt_id.clone();
 469        let app_version = self.app_version.clone();
 470        let thinking_allowed = request.thinking_allowed;
 471        let enable_thinking = thinking_allowed && self.model.supports_thinking;
 472        let provider_name = provider_name(&self.model.provider);
 473        match self.model.provider {
 474            cloud_llm_client::LanguageModelProvider::Anthropic => {
 475                let effort = request
 476                    .thinking_effort
 477                    .as_ref()
 478                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());
 479
 480                let mut request = into_anthropic(
 481                    request,
 482                    self.model.id.to_string(),
 483                    1.0,
 484                    self.model.max_output_tokens as u64,
 485                    if enable_thinking {
 486                        AnthropicModelMode::Thinking {
 487                            budget_tokens: Some(4_096),
 488                        }
 489                    } else {
 490                        AnthropicModelMode::Default
 491                    },
 492                );
 493
 494                if enable_thinking && effort.is_some() {
 495                    request.thinking = Some(anthropic::Thinking::Adaptive);
 496                    request.output_config = Some(anthropic::OutputConfig { effort });
 497                }
 498
 499                if !self.model.supports_fast_mode {
 500                    request.speed = None;
 501                }
 502
 503                let http_client = self.http_client.clone();
 504                let token_provider = self.token_provider.clone();
 505                let auth_context = token_provider.auth_context(cx);
 506                let future = self.request_limiter.stream(async move {
 507                    let PerformLlmCompletionResponse {
 508                        response,
 509                        includes_status_messages,
 510                    } = Self::perform_llm_completion(
 511                        &http_client,
 512                        &*token_provider,
 513                        auth_context,
 514                        app_version,
 515                        CompletionBody {
 516                            thread_id,
 517                            prompt_id,
 518                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
 519                            model: request.model.clone(),
 520                            provider_request: serde_json::to_value(&request)
 521                                .map_err(|e| anyhow!(e))?,
 522                        },
 523                    )
 524                    .await
 525                    .map_err(|err| match err.downcast::<ApiError>() {
 526                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
 527                        Err(err) => anyhow!(err),
 528                    })?;
 529
 530                    let mut mapper = AnthropicEventMapper::new();
 531                    Ok(map_cloud_completion_events(
 532                        Box::pin(response_lines(response, includes_status_messages)),
 533                        &provider_name,
 534                        move |event| mapper.map_event(event),
 535                    ))
 536                });
 537                async move { Ok(future.await?.boxed()) }.boxed()
 538            }
 539            cloud_llm_client::LanguageModelProvider::OpenAi => {
 540                let http_client = self.http_client.clone();
 541                let token_provider = self.token_provider.clone();
 542                let effort = request
 543                    .thinking_effort
 544                    .as_ref()
 545                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());
 546
 547                let mut request = into_open_ai_response(
 548                    request,
 549                    &self.model.id.0,
 550                    self.model.supports_parallel_tool_calls,
 551                    true,
 552                    None,
 553                    None,
 554                );
 555
 556                if enable_thinking && let Some(effort) = effort {
 557                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
 558                        effort,
 559                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
 560                    });
 561                }
 562
 563                let auth_context = token_provider.auth_context(cx);
 564                let future = self.request_limiter.stream(async move {
 565                    let PerformLlmCompletionResponse {
 566                        response,
 567                        includes_status_messages,
 568                    } = Self::perform_llm_completion(
 569                        &http_client,
 570                        &*token_provider,
 571                        auth_context,
 572                        app_version,
 573                        CompletionBody {
 574                            thread_id,
 575                            prompt_id,
 576                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
 577                            model: request.model.clone(),
 578                            provider_request: serde_json::to_value(&request)
 579                                .map_err(|e| anyhow!(e))?,
 580                        },
 581                    )
 582                    .await?;
 583
 584                    let mut mapper = OpenAiResponseEventMapper::new();
 585                    Ok(map_cloud_completion_events(
 586                        Box::pin(response_lines(response, includes_status_messages)),
 587                        &provider_name,
 588                        move |event| mapper.map_event(event),
 589                    ))
 590                });
 591                async move { Ok(future.await?.boxed()) }.boxed()
 592            }
 593            cloud_llm_client::LanguageModelProvider::XAi => {
 594                let http_client = self.http_client.clone();
 595                let token_provider = self.token_provider.clone();
 596                let request = into_open_ai(
 597                    request,
 598                    &self.model.id.0,
 599                    self.model.supports_parallel_tool_calls,
 600                    false,
 601                    None,
 602                    None,
 603                    false,
 604                );
 605                let auth_context = token_provider.auth_context(cx);
 606                let future = self.request_limiter.stream(async move {
 607                    let PerformLlmCompletionResponse {
 608                        response,
 609                        includes_status_messages,
 610                    } = Self::perform_llm_completion(
 611                        &http_client,
 612                        &*token_provider,
 613                        auth_context,
 614                        app_version,
 615                        CompletionBody {
 616                            thread_id,
 617                            prompt_id,
 618                            provider: cloud_llm_client::LanguageModelProvider::XAi,
 619                            model: request.model.clone(),
 620                            provider_request: serde_json::to_value(&request)
 621                                .map_err(|e| anyhow!(e))?,
 622                        },
 623                    )
 624                    .await?;
 625
 626                    let mut mapper = OpenAiEventMapper::new();
 627                    Ok(map_cloud_completion_events(
 628                        Box::pin(response_lines(response, includes_status_messages)),
 629                        &provider_name,
 630                        move |event| mapper.map_event(event),
 631                    ))
 632                });
 633                async move { Ok(future.await?.boxed()) }.boxed()
 634            }
 635            cloud_llm_client::LanguageModelProvider::Google => {
 636                let http_client = self.http_client.clone();
 637                let token_provider = self.token_provider.clone();
 638                let request =
 639                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
 640                let auth_context = token_provider.auth_context(cx);
 641                let future = self.request_limiter.stream(async move {
 642                    let PerformLlmCompletionResponse {
 643                        response,
 644                        includes_status_messages,
 645                    } = Self::perform_llm_completion(
 646                        &http_client,
 647                        &*token_provider,
 648                        auth_context,
 649                        app_version,
 650                        CompletionBody {
 651                            thread_id,
 652                            prompt_id,
 653                            provider: cloud_llm_client::LanguageModelProvider::Google,
 654                            model: request.model.model_id.clone(),
 655                            provider_request: serde_json::to_value(&request)
 656                                .map_err(|e| anyhow!(e))?,
 657                        },
 658                    )
 659                    .await?;
 660
 661                    let mut mapper = GoogleEventMapper::new();
 662                    Ok(map_cloud_completion_events(
 663                        Box::pin(response_lines(response, includes_status_messages)),
 664                        &provider_name,
 665                        move |event| mapper.map_event(event),
 666                    ))
 667                });
 668                async move { Ok(future.await?.boxed()) }.boxed()
 669            }
 670        }
 671    }
 672}
 673
/// Maintains the set of language models available through Zed's cloud and
/// builds [`CloudLanguageModel`] instances for them.
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    // Client version forwarded with requests, when known.
    app_version: Option<Version>,
    // Models most recently returned by the server's `/models` endpoint.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}
 683
 684impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
 685    pub fn new(
 686        token_provider: Arc<TP>,
 687        http_client: Arc<HttpClientWithUrl>,
 688        app_version: Option<Version>,
 689    ) -> Self {
 690        Self {
 691            token_provider,
 692            http_client,
 693            app_version,
 694            models: Vec::new(),
 695            default_model: None,
 696            default_fast_model: None,
 697            recommended_models: Vec::new(),
 698        }
 699    }
 700
    /// Fetches the model list from the server and, on success, replaces the
    /// cached models and notifies observers of this entity.
    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }
 714
 715    async fn fetch_models_request(
 716        http_client: &HttpClientWithUrl,
 717        token_provider: &TP,
 718        auth_context: TP::AuthContext,
 719    ) -> Result<ListModelsResponse> {
 720        let token = token_provider.acquire_token(auth_context).await?;
 721
 722        let request = http_client::Request::builder()
 723            .method(Method::GET)
 724            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 725            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 726            .header("Authorization", format!("Bearer {token}"))
 727            .body(AsyncBody::empty())?;
 728        let mut response = http_client
 729            .send(request)
 730            .await
 731            .context("failed to send list models request")?;
 732
 733        if response.status().is_success() {
 734            let mut body = String::new();
 735            response.body_mut().read_to_string(&mut body).await?;
 736            Ok(serde_json::from_str(&body)?)
 737        } else {
 738            let mut body = String::new();
 739            response.body_mut().read_to_string(&mut body).await?;
 740            anyhow::bail!(
 741                "error listing models.\nStatus: {:?}\nBody: {body}",
 742                response.status(),
 743            );
 744        }
 745    }
 746
 747    pub fn update_models(&mut self, response: ListModelsResponse) {
 748        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();
 749
 750        self.default_model = models
 751            .iter()
 752            .find(|model| {
 753                response
 754                    .default_model
 755                    .as_ref()
 756                    .is_some_and(|default_model_id| &model.id == default_model_id)
 757            })
 758            .cloned();
 759        self.default_fast_model = models
 760            .iter()
 761            .find(|model| {
 762                response
 763                    .default_fast_model
 764                    .as_ref()
 765                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 766            })
 767            .cloned();
 768        self.recommended_models = response
 769            .recommended_models
 770            .iter()
 771            .filter_map(|id| models.iter().find(|model| &model.id == id))
 772            .cloned()
 773            .collect();
 774        self.models = models;
 775    }
 776
 777    pub fn create_model(
 778        &self,
 779        model: &Arc<cloud_llm_client::LanguageModel>,
 780    ) -> Arc<dyn LanguageModel> {
 781        Arc::new(CloudLanguageModel::<TP> {
 782            id: LanguageModelId::from(model.id.0.to_string()),
 783            model: model.clone(),
 784            token_provider: self.token_provider.clone(),
 785            http_client: self.http_client.clone(),
 786            app_version: self.app_version.clone(),
 787            request_limiter: RateLimiter::new(4),
 788        })
 789    }
 790
 791    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
 792        &self.models
 793    }
 794
    /// The server-designated default model, if the last response named one
    /// that was present in the model list.
    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }
 798
    /// The server-designated default "fast" model, if the last response
    /// named one that was present in the model list.
    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }
 802
 803    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
 804        &self.recommended_models
 805    }
 806}
 807
/// Adapts a stream of raw cloud `CompletionEvent`s into a stream of
/// `LanguageModelCompletionEvent`s, using `map_callback` to translate
/// provider-specific payload events.
///
/// Status events are handled here: a `StreamEnded` status is consumed
/// (recorded, not forwarded), and if the underlying stream finishes without
/// ever sending one, a final `StreamEndedUnexpectedly` error is emitted.
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Whether the server explicitly signalled end-of-stream via a
    // `StreamEnded` status event.
    let mut saw_stream_ended = false;

    // `done` latches once the inner stream finishes; `pending` buffers
    // mapped events because one inner event can expand to several outputs.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain any already-mapped events before polling for more.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        // Consume the end-of-stream marker; it produces no
                        // output event of its own.
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        // Other statuses map to zero or one output events.
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // The inner stream ended without a `StreamEnded` status:
                    // surface that as a final error before terminating.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
 877
 878pub fn provider_name(
 879    provider: &cloud_llm_client::LanguageModelProvider,
 880) -> LanguageModelProviderName {
 881    match provider {
 882        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
 883        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
 884        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
 885        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
 886    }
 887}
 888
 889pub fn response_lines<T: DeserializeOwned>(
 890    response: Response<AsyncBody>,
 891    includes_status_messages: bool,
 892) -> impl Stream<Item = Result<CompletionEvent<T>>> {
 893    futures::stream::try_unfold(
 894        (String::new(), BufReader::new(response.into_body())),
 895        move |(mut line, mut body)| async move {
 896            match body.read_line(&mut line).await {
 897                Ok(0) => Ok(None),
 898                Ok(_) => {
 899                    let event = if includes_status_messages {
 900                        serde_json::from_str::<CompletionEvent<T>>(&line)?
 901                    } else {
 902                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
 903                    };
 904
 905                    line.clear();
 906                    Ok(Some((event, (line, body))))
 907                }
 908                Err(e) => Err(e.into()),
 909            }
 910        },
 911    )
 912}
 913
 914#[cfg(test)]
 915mod tests {
 916    use super::*;
 917    use http_client::http::{HeaderMap, StatusCode};
 918    use language_model::LanguageModelCompletionError;
 919
 920    #[test]
 921    fn test_api_error_conversion_with_upstream_http_error() {
 922        // upstream_http_error with 503 status should become ServerOverloaded
 923        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
 924
 925        let api_error = ApiError {
 926            status: StatusCode::INTERNAL_SERVER_ERROR,
 927            body: error_body.to_string(),
 928            headers: HeaderMap::new(),
 929        };
 930
 931        let completion_error: LanguageModelCompletionError = api_error.into();
 932
 933        match completion_error {
 934            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
 935                assert_eq!(
 936                    message,
 937                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
 938                );
 939            }
 940            _ => panic!(
 941                "Expected UpstreamProviderError for upstream 503, got: {:?}",
 942                completion_error
 943            ),
 944        }
 945
 946        // upstream_http_error with 500 status should become ApiInternalServerError
 947        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
 948
 949        let api_error = ApiError {
 950            status: StatusCode::INTERNAL_SERVER_ERROR,
 951            body: error_body.to_string(),
 952            headers: HeaderMap::new(),
 953        };
 954
 955        let completion_error: LanguageModelCompletionError = api_error.into();
 956
 957        match completion_error {
 958            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
 959                assert_eq!(
 960                    message,
 961                    "Received an error from the OpenAI API: internal server error"
 962                );
 963            }
 964            _ => panic!(
 965                "Expected UpstreamProviderError for upstream 500, got: {:?}",
 966                completion_error
 967            ),
 968        }
 969
 970        // upstream_http_error with 429 status should become RateLimitExceeded
 971        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
 972
 973        let api_error = ApiError {
 974            status: StatusCode::INTERNAL_SERVER_ERROR,
 975            body: error_body.to_string(),
 976            headers: HeaderMap::new(),
 977        };
 978
 979        let completion_error: LanguageModelCompletionError = api_error.into();
 980
 981        match completion_error {
 982            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
 983                assert_eq!(
 984                    message,
 985                    "Received an error from the Google API: rate limit exceeded"
 986                );
 987            }
 988            _ => panic!(
 989                "Expected UpstreamProviderError for upstream 429, got: {:?}",
 990                completion_error
 991            ),
 992        }
 993
 994        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
 995        let error_body = "Regular internal server error";
 996
 997        let api_error = ApiError {
 998            status: StatusCode::INTERNAL_SERVER_ERROR,
 999            body: error_body.to_string(),
1000            headers: HeaderMap::new(),
1001        };
1002
1003        let completion_error: LanguageModelCompletionError = api_error.into();
1004
1005        match completion_error {
1006            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
1007                assert_eq!(provider, PROVIDER_NAME);
1008                assert_eq!(message, "Regular internal server error");
1009            }
1010            _ => panic!(
1011                "Expected ApiInternalServerError for regular 500, got: {:?}",
1012                completion_error
1013            ),
1014        }
1015
1016        // upstream_http_429 format should be converted to UpstreamProviderError
1017        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
1018
1019        let api_error = ApiError {
1020            status: StatusCode::INTERNAL_SERVER_ERROR,
1021            body: error_body.to_string(),
1022            headers: HeaderMap::new(),
1023        };
1024
1025        let completion_error: LanguageModelCompletionError = api_error.into();
1026
1027        match completion_error {
1028            LanguageModelCompletionError::UpstreamProviderError {
1029                message,
1030                status,
1031                retry_after,
1032            } => {
1033                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
1034                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
1035                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
1036            }
1037            _ => panic!(
1038                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
1039                completion_error
1040            ),
1041        }
1042
1043        // Invalid JSON in error body should fall back to regular error handling
1044        let error_body = "Not JSON at all";
1045
1046        let api_error = ApiError {
1047            status: StatusCode::INTERNAL_SERVER_ERROR,
1048            body: error_body.to_string(),
1049            headers: HeaderMap::new(),
1050        };
1051
1052        let completion_error: LanguageModelCompletionError = api_error.into();
1053
1054        match completion_error {
1055            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
1056                assert_eq!(provider, PROVIDER_NAME);
1057            }
1058            _ => panic!(
1059                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
1060                completion_error
1061            ),
1062        }
1063    }
1064}