language_models_cloud.rs

   1use anthropic::AnthropicModelMode;
   2use anyhow::{Context as _, Result, anyhow};
   3use cloud_llm_client::{
   4    CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
   5    CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
   6    CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse,
   7    OUTDATED_LLM_TOKEN_HEADER_NAME, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
   8    ZED_VERSION_HEADER_NAME,
   9};
  10use futures::{
  11    AsyncBufReadExt, FutureExt, Stream, StreamExt,
  12    future::BoxFuture,
  13    stream::{self, BoxStream},
  14};
  15use google_ai::GoogleModelMode;
  16use gpui::{App, AppContext, AsyncApp, Context, Task};
  17use http_client::http::{HeaderMap, HeaderValue};
  18use http_client::{
  19    AsyncBody, HttpClient, HttpClientWithUrl, HttpRequestExt, Method, Response, StatusCode,
  20};
  21use language_model::{
  22    ANTHROPIC_PROVIDER_ID, ANTHROPIC_PROVIDER_NAME, GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME,
  23    LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionError,
  24    LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName,
  25    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest,
  26    LanguageModelToolChoice, LanguageModelToolSchemaFormat, OPEN_AI_PROVIDER_ID,
  27    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
  28    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
  29};
  30
  31use schemars::JsonSchema;
  32use semver::Version;
  33use serde::{Deserialize, Serialize, de::DeserializeOwned};
  34use smol::io::{AsyncReadExt, BufReader};
  35use std::collections::VecDeque;
  36use std::pin::Pin;
  37use std::str::FromStr;
  38use std::sync::Arc;
  39use std::task::Poll;
  40use std::time::Duration;
  41use thiserror::Error;
  42
  43use anthropic::completion::{
  44    AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
  45};
  46use google_ai::completion::{GoogleEventMapper, into_google};
  47use open_ai::completion::{
  48    OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
  49    into_open_ai_response,
  50};
  51use x_ai::completion::count_xai_tokens;
  52
/// Provider id under which every cloud-hosted model is registered.
const PROVIDER_ID: LanguageModelProviderId = ZED_CLOUD_PROVIDER_ID;
/// Human-readable provider name matching [`PROVIDER_ID`].
const PROVIDER_NAME: LanguageModelProviderName = ZED_CLOUD_PROVIDER_NAME;
  55
/// Trait for acquiring and refreshing LLM authentication tokens.
pub trait CloudLlmTokenProvider: Send + Sync {
    /// App-side state captured on the main thread and handed to the async
    /// token operations (cloneable so requests can retry).
    type AuthContext: Clone + Send + 'static;

    /// Snapshots the authentication context from the app.
    fn auth_context(&self, cx: &impl AppContext) -> Self::AuthContext;
    /// Obtains a bearer token for LLM requests. Whether a cached token may be
    /// reused is implementation-defined.
    fn acquire_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
    /// Obtains a fresh token after the server reported the current one as
    /// expired or outdated.
    fn refresh_token(&self, auth_context: Self::AuthContext) -> BoxFuture<'static, Result<String>>;
}
  64
/// User-configurable reasoning mode for a model, serialized with a lowercase
/// `type` tag (e.g. `{"type": "thinking", "budget_tokens": 4096}`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard completion mode without extended reasoning.
    #[default]
    Default,
    /// Extended-reasoning ("thinking") mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  75
  76impl From<ModelMode> for AnthropicModelMode {
  77    fn from(value: ModelMode) -> Self {
  78        match value {
  79            ModelMode::Default => AnthropicModelMode::Default,
  80            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
  81        }
  82    }
  83}
  84
/// A language model served through Zed's cloud proxy, backed by one of the
/// upstream providers described in the catalog entry.
pub struct CloudLanguageModel<TP: CloudLlmTokenProvider> {
    // Identifier exposed to the model registry.
    pub id: LanguageModelId,
    // Catalog entry describing capabilities, limits, and upstream provider.
    pub model: Arc<cloud_llm_client::LanguageModel>,
    // Source of bearer tokens for authenticated cloud requests.
    pub token_provider: Arc<TP>,
    pub http_client: Arc<HttpClientWithUrl>,
    // Sent in the Zed version header when present.
    pub app_version: Option<Version>,
    // Bounds the number of concurrent in-flight requests for this model.
    pub request_limiter: RateLimiter,
}
  93
/// Successful result of [`CloudLanguageModel::perform_llm_completion`].
pub struct PerformLlmCompletionResponse {
    // The raw streaming HTTP response from the `/completions` endpoint.
    pub response: Response<AsyncBody>,
    // True when the server advertised that status messages are interleaved
    // in the event stream (via the corresponding response header).
    pub includes_status_messages: bool,
}
  98
impl<TP: CloudLlmTokenProvider> CloudLanguageModel<TP> {
    /// Sends a completion request to the cloud `/completions` endpoint.
    ///
    /// Acquires an LLM token, POSTs the serialized `body`, and returns the
    /// streaming response. If the server signals an expired or outdated token,
    /// the token is refreshed once and the request retried. A 402 status is
    /// surfaced as [`PaymentRequiredError`]; any other non-success status is
    /// returned as an [`ApiError`] carrying status, body, and headers.
    pub async fn perform_llm_completion(
        http_client: &HttpClientWithUrl,
        token_provider: &TP,
        auth_context: TP::AuthContext,
        app_version: Option<Version>,
        body: CompletionBody,
    ) -> Result<PerformLlmCompletionResponse> {
        let mut token = token_provider.acquire_token(auth_context.clone()).await?;
        let mut refreshed_token = false;

        loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
                .when_some(app_version.as_ref(), |builder, app_version| {
                    builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                })
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                // Advertise protocol capabilities so the server may interleave
                // status messages and emit a stream-ended marker.
                .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
                .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
                .body(serde_json::to_string(&body)?.into())?;

            let mut response = http_client.send(request).await?;
            let status = response.status();
            if status.is_success() {
                let includes_status_messages = response
                    .headers()
                    .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
                    .is_some();

                return Ok(PerformLlmCompletionResponse {
                    response,
                    includes_status_messages,
                });
            }

            // Retry exactly once after refreshing a stale token; the flag
            // prevents an infinite refresh loop.
            if !refreshed_token && needs_llm_token_refresh(&response) {
                token = token_provider.refresh_token(auth_context.clone()).await?;
                refreshed_token = true;
                continue;
            }

            if status == StatusCode::PAYMENT_REQUIRED {
                return Err(anyhow!(PaymentRequiredError));
            }

            // Capture headers before consuming the body for the error payload.
            let mut body = String::new();
            let headers = response.headers().clone();
            response.body_mut().read_to_string(&mut body).await?;
            return Err(anyhow!(ApiError {
                status,
                body,
                headers
            }));
        }
    }
}
 158
 159fn needs_llm_token_refresh(response: &Response<AsyncBody>) -> bool {
 160    response
 161        .headers()
 162        .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 163        .is_some()
 164        || response
 165            .headers()
 166            .get(OUTDATED_LLM_TOKEN_HEADER_NAME)
 167            .is_some()
 168}
 169
/// Non-success HTTP response from the cloud API, preserving everything needed
/// to classify the failure downstream (see `From<ApiError>` below).
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    // Raw response body; may contain a JSON `CloudApiError` payload.
    body: String,
    headers: HeaderMap<HeaderValue>,
}
 177
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    // Machine-readable error code, e.g. "upstream_http_error" or
    // "upstream_http_429".
    code: String,
    message: String,
    // HTTP status reported by the upstream provider, when present. Invalid
    // status numbers deserialize to `None` rather than failing.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    // Suggested retry delay in (possibly fractional) seconds.
    #[serde(default)]
    retry_after: Option<f64>,
}
 198
 199fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
 200where
 201    D: serde::Deserializer<'de>,
 202{
 203    let opt: Option<u16> = Option::deserialize(deserializer)?;
 204    Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
 205}
 206
 207impl From<ApiError> for LanguageModelCompletionError {
 208    fn from(error: ApiError) -> Self {
 209        if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
 210            if cloud_error.code.starts_with("upstream_http_") {
 211                let status = if let Some(status) = cloud_error.upstream_status {
 212                    status
 213                } else if cloud_error.code.ends_with("_error") {
 214                    error.status
 215                } else {
 216                    // If there's a status code in the code string (e.g. "upstream_http_429")
 217                    // then use that; otherwise, see if the JSON contains a status code.
 218                    cloud_error
 219                        .code
 220                        .strip_prefix("upstream_http_")
 221                        .and_then(|code_str| code_str.parse::<u16>().ok())
 222                        .and_then(|code| StatusCode::from_u16(code).ok())
 223                        .unwrap_or(error.status)
 224                };
 225
 226                return LanguageModelCompletionError::UpstreamProviderError {
 227                    message: cloud_error.message,
 228                    status,
 229                    retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
 230                };
 231            }
 232
 233            return LanguageModelCompletionError::from_http_status(
 234                PROVIDER_NAME,
 235                error.status,
 236                cloud_error.message,
 237                None,
 238            );
 239        }
 240
 241        let retry_after = None;
 242        LanguageModelCompletionError::from_http_status(
 243            PROVIDER_NAME,
 244            error.status,
 245            error.body,
 246            retry_after,
 247        )
 248    }
 249}
 250
impl<TP: CloudLlmTokenProvider + 'static> LanguageModel for CloudLanguageModel<TP> {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        // Every cloud-hosted model is surfaced under the single Zed provider.
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    // Maps the catalog's provider enum onto the id of the upstream provider
    // (Anthropic, OpenAI, Google, or xAI) that actually serves this model.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_ID,
            OpenAi => OPEN_AI_PROVIDER_ID,
            Google => GOOGLE_PROVIDER_ID,
            XAi => X_AI_PROVIDER_ID,
        }
    }

    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => ANTHROPIC_PROVIDER_NAME,
            OpenAi => OPEN_AI_PROVIDER_NAME,
            Google => GOOGLE_PROVIDER_NAME,
            XAi => X_AI_PROVIDER_NAME,
        }
    }

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    // Converts the catalog's effort-level descriptors into the
    // language-model-layer representation; a missing `is_default` is false.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        // All tool-choice modes are accepted for every cloud model.
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    // Google and xAI only accept a subset of JSON Schema in tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    // Prompt caching is only configured for Anthropic models.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    // Counts request tokens. Anthropic/OpenAI/xAI are counted locally on a
    // background task; Google requires a round-trip to the cloud
    // `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            // Approximated locally via the tiktoken-based counter.
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_open_ai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                cx.background_spawn(async move { count_xai_tokens(request, model) })
                    .boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                async move {
                    let token = token_provider.acquire_token(auth_context).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion through the cloud proxy.
    ///
    /// Translates the generic request into the upstream provider's wire
    /// format, POSTs it via [`Self::perform_llm_completion`] behind the
    /// request limiter, and maps the server-sent event lines back into
    /// [`LanguageModelCompletionEvent`]s with the provider-specific mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let app_version = self.app_version.clone();
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        // `provider_name` here is a free helper defined elsewhere in this
        // file; the local shadows it with the resolved name for this model.
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                // When an effort level is given, switch to adaptive thinking
                // and pass the effort through the output config.
                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                if !self.model.supports_fast_mode {
                    request.speed = None;
                }

                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // Surface structured API errors as completion errors.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    // NOTE(review): unlike the Anthropic branch, ApiError
                    // results here (and in the XAi/Google branches) are not
                    // downcast into LanguageModelCompletionError — confirm
                    // whether that asymmetry is intentional.
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let http_client = self.http_client.clone();
                let token_provider = self.token_provider.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let auth_context = token_provider.auth_context(cx);
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        &http_client,
                        &*token_provider,
                        auth_context,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
 672
/// Maintains the catalog of cloud-hosted models and constructs
/// [`CloudLanguageModel`] instances for them.
pub struct CloudModelProvider<TP: CloudLlmTokenProvider> {
    token_provider: Arc<TP>,
    http_client: Arc<HttpClientWithUrl>,
    app_version: Option<Version>,
    // Catalog as last fetched from the server; empty until refreshed.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
}
 682
 683impl<TP: CloudLlmTokenProvider + 'static> CloudModelProvider<TP> {
    /// Creates a provider with an empty model catalog; call
    /// [`Self::refresh_models`] to populate it from the server.
    pub fn new(
        token_provider: Arc<TP>,
        http_client: Arc<HttpClientWithUrl>,
        app_version: Option<Version>,
    ) -> Self {
        Self {
            token_provider,
            http_client,
            app_version,
            models: Vec::new(),
            default_model: None,
            default_fast_model: None,
            recommended_models: Vec::new(),
        }
    }
 699
    /// Fetches the model list from the server on a spawned task and updates
    /// this provider's cached catalog, notifying observers of the change.
    pub fn refresh_models(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = self.http_client.clone();
        let token_provider = self.token_provider.clone();
        cx.spawn(async move |this, cx| {
            let auth_context = token_provider.auth_context(cx);
            let response =
                Self::fetch_models_request(&http_client, &*token_provider, auth_context).await?;
            // Fails (and propagates) if the provider entity was dropped.
            this.update(cx, |this, cx| {
                this.update_models(response);
                cx.notify();
            })
        })
    }
 713
 714    async fn fetch_models_request(
 715        http_client: &HttpClientWithUrl,
 716        token_provider: &TP,
 717        auth_context: TP::AuthContext,
 718    ) -> Result<ListModelsResponse> {
 719        let token = token_provider.acquire_token(auth_context).await?;
 720
 721        let request = http_client::Request::builder()
 722            .method(Method::GET)
 723            .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
 724            .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
 725            .header("Authorization", format!("Bearer {token}"))
 726            .body(AsyncBody::empty())?;
 727        let mut response = http_client
 728            .send(request)
 729            .await
 730            .context("failed to send list models request")?;
 731
 732        if response.status().is_success() {
 733            let mut body = String::new();
 734            response.body_mut().read_to_string(&mut body).await?;
 735            Ok(serde_json::from_str(&body)?)
 736        } else {
 737            let mut body = String::new();
 738            response.body_mut().read_to_string(&mut body).await?;
 739            anyhow::bail!(
 740                "error listing models.\nStatus: {:?}\nBody: {body}",
 741                response.status(),
 742            );
 743        }
 744    }
 745
 746    pub fn update_models(&mut self, response: ListModelsResponse) {
 747        let models: Vec<_> = response.models.into_iter().map(Arc::new).collect();
 748
 749        self.default_model = models
 750            .iter()
 751            .find(|model| {
 752                response
 753                    .default_model
 754                    .as_ref()
 755                    .is_some_and(|default_model_id| &model.id == default_model_id)
 756            })
 757            .cloned();
 758        self.default_fast_model = models
 759            .iter()
 760            .find(|model| {
 761                response
 762                    .default_fast_model
 763                    .as_ref()
 764                    .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
 765            })
 766            .cloned();
 767        self.recommended_models = response
 768            .recommended_models
 769            .iter()
 770            .filter_map(|id| models.iter().find(|model| &model.id == id))
 771            .cloned()
 772            .collect();
 773        self.models = models;
 774    }
 775
 776    pub fn create_model(
 777        &self,
 778        model: &Arc<cloud_llm_client::LanguageModel>,
 779    ) -> Arc<dyn LanguageModel> {
 780        Arc::new(CloudLanguageModel::<TP> {
 781            id: LanguageModelId::from(model.id.0.to_string()),
 782            model: model.clone(),
 783            token_provider: self.token_provider.clone(),
 784            http_client: self.http_client.clone(),
 785            app_version: self.app_version.clone(),
 786            request_limiter: RateLimiter::new(4),
 787        })
 788    }
 789
    /// All models received in the most recent `ListModelsResponse`.
    pub fn models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.models
    }

    /// The server-designated default model, if one was named in the most
    /// recent `ListModelsResponse` and present in its model list.
    pub fn default_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_model.as_ref()
    }

    /// The server-designated default "fast" model, if one was named in the
    /// most recent `ListModelsResponse` and present in its model list.
    pub fn default_fast_model(&self) -> Option<&Arc<cloud_llm_client::LanguageModel>> {
        self.default_fast_model.as_ref()
    }

    /// Models the server marked as recommended, in the server's order.
    pub fn recommended_models(&self) -> &[Arc<cloud_llm_client::LanguageModel>] {
        &self.recommended_models
    }
 805}
 806
 807pub fn map_cloud_completion_events<T, F>(
 808    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
 809    provider: &LanguageModelProviderName,
 810    mut map_callback: F,
 811) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 812where
 813    T: DeserializeOwned + 'static,
 814    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 815        + Send
 816        + 'static,
 817{
 818    let provider = provider.clone();
 819    let mut stream = stream.fuse();
 820
 821    let mut saw_stream_ended = false;
 822
 823    let mut done = false;
 824    let mut pending = VecDeque::new();
 825
 826    stream::poll_fn(move |cx| {
 827        loop {
 828            if let Some(item) = pending.pop_front() {
 829                return Poll::Ready(Some(item));
 830            }
 831
 832            if done {
 833                return Poll::Ready(None);
 834            }
 835
 836            match stream.poll_next_unpin(cx) {
 837                Poll::Ready(Some(event)) => {
 838                    let items = match event {
 839                        Err(error) => {
 840                            vec![Err(LanguageModelCompletionError::from(error))]
 841                        }
 842                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
 843                            saw_stream_ended = true;
 844                            vec![]
 845                        }
 846                        Ok(CompletionEvent::Status(status)) => {
 847                            LanguageModelCompletionEvent::from_completion_request_status(
 848                                status,
 849                                provider.clone(),
 850                            )
 851                            .transpose()
 852                            .map(|event| vec![event])
 853                            .unwrap_or_default()
 854                        }
 855                        Ok(CompletionEvent::Event(event)) => map_callback(event),
 856                    };
 857                    pending.extend(items);
 858                }
 859                Poll::Ready(None) => {
 860                    done = true;
 861
 862                    if !saw_stream_ended {
 863                        return Poll::Ready(Some(Err(
 864                            LanguageModelCompletionError::StreamEndedUnexpectedly {
 865                                provider: provider.clone(),
 866                            },
 867                        )));
 868                    }
 869                }
 870                Poll::Pending => return Poll::Pending,
 871            }
 872        }
 873    })
 874    .boxed()
 875}
 876
 877pub fn provider_name(
 878    provider: &cloud_llm_client::LanguageModelProvider,
 879) -> LanguageModelProviderName {
 880    match provider {
 881        cloud_llm_client::LanguageModelProvider::Anthropic => ANTHROPIC_PROVIDER_NAME,
 882        cloud_llm_client::LanguageModelProvider::OpenAi => OPEN_AI_PROVIDER_NAME,
 883        cloud_llm_client::LanguageModelProvider::Google => GOOGLE_PROVIDER_NAME,
 884        cloud_llm_client::LanguageModelProvider::XAi => X_AI_PROVIDER_NAME,
 885    }
 886}
 887
 888pub fn response_lines<T: DeserializeOwned>(
 889    response: Response<AsyncBody>,
 890    includes_status_messages: bool,
 891) -> impl Stream<Item = Result<CompletionEvent<T>>> {
 892    futures::stream::try_unfold(
 893        (String::new(), BufReader::new(response.into_body())),
 894        move |(mut line, mut body)| async move {
 895            match body.read_line(&mut line).await {
 896                Ok(0) => Ok(None),
 897                Ok(_) => {
 898                    let event = if includes_status_messages {
 899                        serde_json::from_str::<CompletionEvent<T>>(&line)?
 900                    } else {
 901                        CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
 902                    };
 903
 904                    line.clear();
 905                    Ok(Some((event, (line, body))))
 906                }
 907                Err(e) => Err(e.into()),
 908            }
 909        },
 910    )
 911}
 912
 913#[cfg(test)]
 914mod tests {
 915    use super::*;
 916    use http_client::http::{HeaderMap, StatusCode};
 917    use language_model::LanguageModelCompletionError;
 918
 919    #[test]
 920    fn test_api_error_conversion_with_upstream_http_error() {
 921        // upstream_http_error with 503 status should become ServerOverloaded
 922        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#;
 923
 924        let api_error = ApiError {
 925            status: StatusCode::INTERNAL_SERVER_ERROR,
 926            body: error_body.to_string(),
 927            headers: HeaderMap::new(),
 928        };
 929
 930        let completion_error: LanguageModelCompletionError = api_error.into();
 931
 932        match completion_error {
 933            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
 934                assert_eq!(
 935                    message,
 936                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
 937                );
 938            }
 939            _ => panic!(
 940                "Expected UpstreamProviderError for upstream 503, got: {:?}",
 941                completion_error
 942            ),
 943        }
 944
 945        // upstream_http_error with 500 status should become ApiInternalServerError
 946        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#;
 947
 948        let api_error = ApiError {
 949            status: StatusCode::INTERNAL_SERVER_ERROR,
 950            body: error_body.to_string(),
 951            headers: HeaderMap::new(),
 952        };
 953
 954        let completion_error: LanguageModelCompletionError = api_error.into();
 955
 956        match completion_error {
 957            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
 958                assert_eq!(
 959                    message,
 960                    "Received an error from the OpenAI API: internal server error"
 961                );
 962            }
 963            _ => panic!(
 964                "Expected UpstreamProviderError for upstream 500, got: {:?}",
 965                completion_error
 966            ),
 967        }
 968
 969        // upstream_http_error with 429 status should become RateLimitExceeded
 970        let error_body = r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#;
 971
 972        let api_error = ApiError {
 973            status: StatusCode::INTERNAL_SERVER_ERROR,
 974            body: error_body.to_string(),
 975            headers: HeaderMap::new(),
 976        };
 977
 978        let completion_error: LanguageModelCompletionError = api_error.into();
 979
 980        match completion_error {
 981            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
 982                assert_eq!(
 983                    message,
 984                    "Received an error from the Google API: rate limit exceeded"
 985                );
 986            }
 987            _ => panic!(
 988                "Expected UpstreamProviderError for upstream 429, got: {:?}",
 989                completion_error
 990            ),
 991        }
 992
 993        // Regular 500 error without upstream_http_error should remain ApiInternalServerError for Zed
 994        let error_body = "Regular internal server error";
 995
 996        let api_error = ApiError {
 997            status: StatusCode::INTERNAL_SERVER_ERROR,
 998            body: error_body.to_string(),
 999            headers: HeaderMap::new(),
1000        };
1001
1002        let completion_error: LanguageModelCompletionError = api_error.into();
1003
1004        match completion_error {
1005            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
1006                assert_eq!(provider, PROVIDER_NAME);
1007                assert_eq!(message, "Regular internal server error");
1008            }
1009            _ => panic!(
1010                "Expected ApiInternalServerError for regular 500, got: {:?}",
1011                completion_error
1012            ),
1013        }
1014
1015        // upstream_http_429 format should be converted to UpstreamProviderError
1016        let error_body = r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#;
1017
1018        let api_error = ApiError {
1019            status: StatusCode::INTERNAL_SERVER_ERROR,
1020            body: error_body.to_string(),
1021            headers: HeaderMap::new(),
1022        };
1023
1024        let completion_error: LanguageModelCompletionError = api_error.into();
1025
1026        match completion_error {
1027            LanguageModelCompletionError::UpstreamProviderError {
1028                message,
1029                status,
1030                retry_after,
1031            } => {
1032                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
1033                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
1034                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
1035            }
1036            _ => panic!(
1037                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
1038                completion_error
1039            ),
1040        }
1041
1042        // Invalid JSON in error body should fall back to regular error handling
1043        let error_body = "Not JSON at all";
1044
1045        let api_error = ApiError {
1046            status: StatusCode::INTERNAL_SERVER_ERROR,
1047            body: error_body.to_string(),
1048            headers: HeaderMap::new(),
1049        };
1050
1051        let completion_error: LanguageModelCompletionError = api_error.into();
1052
1053        match completion_error {
1054            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
1055                assert_eq!(provider, PROVIDER_NAME);
1056            }
1057            _ => panic!(
1058                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
1059                completion_error
1060            ),
1061        }
1062    }
1063}