language_models_cloud.rs

use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::{
    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
    OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
    ZED_VERSION_HEADER_NAME,
};
use futures::{
    AsyncBufReadExt, FutureExt, Stream, StreamExt,
    future::BoxFuture,
    stream::{self, BoxStream},
};
use google_ai::GoogleModelMode;
use gpui::{App, AppContext, AsyncApp, Context, Task};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{
    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
};
use language_model::{
    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
};

use schemars::JsonSchema;
use semver::Version;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use smol::io::{AsyncReadExt, BufReader};
use std::collections::VecDeque;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use thiserror::Error;

use anthropic::completion::{
    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
};
use google_ai::completion::{GoogleEventMapper, into_google};
use open_ai::completion::{
    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
    into_open_ai_response,
};
use x_ai::completion::count_xai_tokens;

const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;

/// Trait for acquiring and refreshing LLM authentication tokens.
pub trait CloudLlmTokenProvider: Send + Sync {
    type AuthContext: Clone + Send + 'static;

    fn auth_context(&self, cx: &AsyncApp) -> Self::AuthContext;
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
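
// A minimal sketch of an implementation, kept commented out for illustration
// only. `StaticTokenProvider` is hypothetical and not part of this crate;
// real providers typically derive `AuthContext` from the signed-in session
// and mint a fresh token in `refresh_token`.
//
//     struct StaticTokenProvider {
//         token: String,
//     }
//
//     impl CloudLlmTokenProvider for StaticTokenProvider {
//         type AuthContext = ();
//
//         fn auth_context(&self, _cx: &AsyncApp) -> Self::AuthContext {}
//
//         fn acquire_token(&self, _auth: Self::AuthContext) -> BoxFuture<'static, Result<String>> {
//             let token = self.token.clone();
//             async move { Ok(token) }.boxed()
//         }
//
//         fn refresh_token(&self, auth: Self::AuthContext) -> BoxFuture<'static, Result<String>> {
//             // A static token cannot be re-minted, so refreshing just reuses it.
//             self.acquire_token(auth)
//         }
//     }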

#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
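
// With the `tag = "type"` / lowercase serde attributes above, the variants
// serialize as, for example:
//
//     {"type": "default"}
//     {"type": "thinking", "budget_tokens": 4096}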

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    pub id: LanguageModelId,
    pub model: Arc<cloud_llm_client::LanguageModel>,
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    pub app_version: Option<Version>,
    pub request_limiter: RateLimiter,
}

pub struct PerformLlmCompletionResponse {
    pub response: Response<AsyncBody>,
    pub includes_status_messages: bool,
}

impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
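    /// Sends a completion request to the cloud `/completions` endpoint.
    /// If the server reports the LLM token as expired or outdated, the token
    /// is refreshed once and the request retried. A 402 response becomes
    /// `PaymentRequiredError`; any other non-success response is surfaced as
    /// an `ApiError`.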
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}

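/// Reports whether the server flagged the request's LLM token as expired or
/// outdated, in which case the caller should mint a fresh token and retry.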
fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
    response
        .headers()
        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
        .is_some()
        || response
            .headers()
            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
            .is_some()
}

#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
    headers: HeaderMap<HeaderValue>,
}

/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    code: String,
    message: String,
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    #[serde(default)]
    retry_after: Option<f64>,
}

fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let opt: Option<u16> = Option::deserialize(deserializer)?;
    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
}

impl From<ApiError> for LanguageModelCompletionError {
    fn from(error: ApiError) -> Self {
        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
            if cloud_error.code.starts_with("upstream_http_") {
                let status = if let Some(status) = cloud_error.upstream_status {
                    status
                } else if cloud_error.code.ends_with("_error") {
                    error.status
                } else {
                    // If the code string embeds a status (e.g. "upstream_http_429"),
                    // use it; otherwise fall back to the response's own status.
                    cloud_error
                        .code
                        .strip_prefix("upstream_http_")
                        .and_then(|code_str| code_str.parse::<u16>().ok())
                        .and_then(|code| StatusCode::from_u16(code).ok())
                        .unwrap_or(error.status)
                };

                return LanguageModelCompletionError::UpstreamProviderError {
                    message: cloud_error.message,
                    status,
                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
                };
            }

            return LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                error.status,
                cloud_error.message,
                None,
            );
        }

        let retry_after = None;
        LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            error.status,
            error.body,
            retry_after,
        )
    }
}

impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

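    /// Counts tokens locally for Anthropic (tiktoken-based), OpenAI, and xAI
    /// models; Google requests are counted by the cloud `/count_tokens`
    /// endpoint.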
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_open_ai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_xai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(&cx.to_async());
                async move {
                    let token = token_provider.acquire_token(auth_context).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

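    /// Streams a completion by translating the request into the upstream
    /// provider's wire format, posting it to the cloud `/completions`
    /// endpoint, and mapping the newline-delimited response stream back into
    /// `LanguageModelCompletionEvent`s.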
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}

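/// Maintains the catalog of cloud-hosted models fetched from the `/models`
/// endpoint and constructs `CloudLanguageModel` instances that share one
/// token provider and HTTP client.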
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}

impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }

    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }

    async fn fetch_models_request(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
    ) -> Result<ListModelsResponse> {
        let token = token_provider.acquire_token(auth_context).await?;

        let request = http_client::Request::builder()
            .method(Method::GET)
            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
            .header("Authorization", format!("Bearer {token}"))
            .body(AsyncBody::empty())?;
        let mut response = http_client
            .send(request)
            .await
            .context("failed to send list models request")?;

        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        if response.status().is_success() {
            Ok(serde_json::from_str(&body)?)
        } else {
            anyhow::bail!(
                "error listing models.\nStatus: {:?}\nBody: {body}",
                response.status(),
            );
        }
    }

    pub fn update_models(&mut self, response: ListModelsResponse) {
        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();

        self.default_model = models
            .iter()
            .find(|model| {
                response
                    .default_model
                    .as_ref()
                    .is_some_and(|default_model_id| &model.id == default_model_id)
            })
            .cloned();
        self.default_fast_model = models
            .iter()
            .find(|model| {
                response
                    .default_fast_model
                    .as_ref()
                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
            })
            .cloned();
        self.recommended_models = response
            .recommended_models
            .iter()
            .filter_map(|id| models.iter().find(|model| &model.id == id))
            .cloned()
            .collect();
        self.models = models;
    }

    pub fn create_model(
        &self,
        model: &Arc<cloud_llm_client::LanguageModel>,
    ) -> Arc<dyn LanguageModel> {
        Arc::new(CloudLanguageModel::<TP> {
            id: LanguageModelId::from(model.id.0.to_string()),
            model: model.clone(),
            token_provider: self.token_provider.clone(),
            http_client: self.http_client.clone(),
            app_version: self.app_version.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
}
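
// A hedged usage sketch (commented out; `token_provider`, `http_client`, and
// `app_version` are assumed to be in scope, and the provider is assumed to
// live inside a gpui entity; none of these bindings are defined in this file):
//
//     let provider = cx.new(|_| CloudModelProvider::new(token_provider, http_client, app_version));
//     provider.update(cx, |provider, cx| provider.refresh_models(cx).detach());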
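/// Adapts a raw cloud completion stream into provider-agnostic events:
/// status messages are handled here, payload events of type `T` are
/// translated by `map_callback`, and a stream that ends without an explicit
/// `StreamEnded` status yields a `StreamEndedUnexpectedly` error so callers
/// can distinguish truncation from normal completion.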
pub fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    let mut saw_stream_ended = false;

    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}

pub fn provider_name(
    provider: &cloud_llm_client::LanguageModelProvider,
) -> LanguageModelProviderName {
    match provider {
        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
    }
}
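/// Parses a streaming response body as newline-delimited JSON. When the
/// server advertised status-message support, each line is a complete
/// `CompletionEvent<T>`; otherwise each line is a bare provider event that is
/// wrapped in `CompletionEvent::Event`.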
pub fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CompletionEvent<T>>(&line)?
                    } else {
                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // upstream_http_error with an upstream 503 should become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 500 should also become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_error with an upstream 429 should also become UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
        let error_body = "Regular internal server error";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // upstream_http_429 format should be converted to UpstreamProviderError
        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // Invalid JSON in error body should fall back to regular error handling
        let error_body = "Not JSON at all";

        let api_error = ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: error_body.to_string(),
            headers: HeaderMap::new(),
        };

        let completion_error: LanguageModelCompletionError = api_error.into();

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}