language_model.rs

   1mod api_key;
   2mod model;
   3mod rate_limiter;
   4mod registry;
   5mod request;
   6mod role;
   7mod telemetry;
   8pub mod tool_schema;
   9
  10#[cfg(any(test, feature = "test-support"))]
  11pub mod fake_provider;
  12
  13use anthropic::{AnthropicError, parse_prompt_too_long};
  14use anyhow::{Result, anyhow};
  15use client::Client;
  16use client::UserStore;
  17use cloud_llm_client::CompletionRequestStatus;
  18use futures::FutureExt;
  19use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
  20use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
  21use http_client::{StatusCode, http};
  22use icons::IconName;
  23use open_router::OpenRouterError;
  24use parking_lot::Mutex;
  25use serde::{Deserialize, Serialize};
  26pub use settings::LanguageModelCacheConfiguration;
  27use std::ops::{Add, Sub};
  28use std::str::FromStr;
  29use std::sync::Arc;
  30use std::time::Duration;
  31use std::{fmt, io};
  32use thiserror::Error;
  33use util::serde::is_default;
  34
  35pub use crate::api_key::{ApiKey, ApiKeyState};
  36pub use crate::model::*;
  37pub use crate::rate_limiter::*;
  38pub use crate::registry::*;
  39pub use crate::request::*;
  40pub use crate::role::*;
  41pub use crate::telemetry::*;
  42pub use crate::tool_schema::LanguageModelToolSchemaFormat;
  43pub use zed_env_vars::{EnvVar, env_var};
  44
// Well-known provider ids (stable machine identifiers) and their
// human-readable display names.

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
  64
/// Initializes this crate: registers settings state and the LLM token
/// refresh listener for the given client and user store.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}
  69
/// Initializes only the settings-level state (the language model registry),
/// without registering any listeners.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
  73
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The provider has started processing the request.
    Started,
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking output, with an optional signature
    /// some models require to be echoed back in later requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-redacted thinking content.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The raw input streamed for a tool call failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message started, carrying the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through verbatim.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counts for the request so far.
    UsageUpdate(TokenUsage),
}
 103
 104impl LanguageModelCompletionEvent {
 105    pub fn from_completion_request_status(
 106        status: CompletionRequestStatus,
 107        upstream_provider: LanguageModelProviderName,
 108    ) -> Result<Option<Self>, LanguageModelCompletionError> {
 109        match status {
 110            CompletionRequestStatus::Queued { position } => {
 111                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
 112            }
 113            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
 114            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
 115            CompletionRequestStatus::Failed {
 116                code,
 117                message,
 118                request_id: _,
 119                retry_after,
 120            } => Err(LanguageModelCompletionError::from_cloud_failure(
 121                upstream_provider,
 122                code,
 123                message,
 124                retry_after.map(Duration::from_secs_f64),
 125            )),
 126        }
 127    }
 128}
 129
/// An error produced while requesting or streaming a completion.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeds the model's context window; `tokens` is the
    /// offending count when the provider reports one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error message with the upstream provider's HTTP status and an
    /// optional retry delay.
    // NOTE(review): not constructed in this file — verify construction sites
    // before changing its shape.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not mapped to a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 220
 221impl LanguageModelCompletionError {
 222    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
 223        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
 224        let upstream_status = error_json
 225            .get("upstream_status")
 226            .and_then(|v| v.as_u64())
 227            .and_then(|status| u16::try_from(status).ok())
 228            .and_then(|status| StatusCode::from_u16(status).ok())?;
 229        let inner_message = error_json
 230            .get("message")
 231            .and_then(|v| v.as_str())
 232            .unwrap_or(message)
 233            .to_string();
 234        Some((upstream_status, inner_message))
 235    }
 236
 237    pub fn from_cloud_failure(
 238        upstream_provider: LanguageModelProviderName,
 239        code: String,
 240        message: String,
 241        retry_after: Option<Duration>,
 242    ) -> Self {
 243        if let Some(tokens) = parse_prompt_too_long(&message) {
 244            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
 245            // to be reported. This is a temporary workaround to handle this in the case where the
 246            // token limit has been exceeded.
 247            Self::PromptTooLarge {
 248                tokens: Some(tokens),
 249            }
 250        } else if code == "upstream_http_error" {
 251            if let Some((upstream_status, inner_message)) =
 252                Self::parse_upstream_error_json(&message)
 253            {
 254                return Self::from_http_status(
 255                    upstream_provider,
 256                    upstream_status,
 257                    inner_message,
 258                    retry_after,
 259                );
 260            }
 261            anyhow!("completion request failed, code: {code}, message: {message}").into()
 262        } else if let Some(status_code) = code
 263            .strip_prefix("upstream_http_")
 264            .and_then(|code| StatusCode::from_str(code).ok())
 265        {
 266            Self::from_http_status(upstream_provider, status_code, message, retry_after)
 267        } else if let Some(status_code) = code
 268            .strip_prefix("http_")
 269            .and_then(|code| StatusCode::from_str(code).ok())
 270        {
 271            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
 272        } else {
 273            anyhow!("completion request failed, code: {code}, message: {message}").into()
 274        }
 275    }
 276
 277    pub fn from_http_status(
 278        provider: LanguageModelProviderName,
 279        status_code: StatusCode,
 280        message: String,
 281        retry_after: Option<Duration>,
 282    ) -> Self {
 283        match status_code {
 284            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
 285            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
 286            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
 287            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
 288            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
 289                tokens: parse_prompt_too_long(&message),
 290            },
 291            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
 292                provider,
 293                retry_after,
 294            },
 295            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
 296            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
 297                provider,
 298                retry_after,
 299            },
 300            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
 301                provider,
 302                retry_after,
 303            },
 304            _ => Self::HttpResponseError {
 305                provider,
 306                status_code,
 307                message,
 308            },
 309        }
 310    }
 311}
 312
/// Converts transport- and protocol-level Anthropic errors into typed
/// completion errors, attributing them all to the Anthropic provider.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // Structured API errors get their own, code-based mapping below.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
 344
 345impl From<anthropic::ApiError> for LanguageModelCompletionError {
 346    fn from(error: anthropic::ApiError) -> Self {
 347        use anthropic::ApiErrorCode::*;
 348        let provider = ANTHROPIC_PROVIDER_NAME;
 349        match error.code() {
 350            Some(code) => match code {
 351                InvalidRequestError => Self::BadRequestFormat {
 352                    provider,
 353                    message: error.message,
 354                },
 355                AuthenticationError => Self::AuthenticationError {
 356                    provider,
 357                    message: error.message,
 358                },
 359                PermissionError => Self::PermissionError {
 360                    provider,
 361                    message: error.message,
 362                },
 363                NotFoundError => Self::ApiEndpointNotFound { provider },
 364                RequestTooLarge => Self::PromptTooLarge {
 365                    tokens: parse_prompt_too_long(&error.message),
 366                },
 367                RateLimitError => Self::RateLimitExceeded {
 368                    provider,
 369                    retry_after: None,
 370                },
 371                ApiError => Self::ApiInternalServerError {
 372                    provider,
 373                    message: error.message,
 374                },
 375                OverloadedError => Self::ServerOverloaded {
 376                    provider,
 377                    retry_after: None,
 378                },
 379            },
 380            None => Self::Other(error.into()),
 381        }
 382    }
 383}
 384
 385impl From<open_ai::RequestError> for LanguageModelCompletionError {
 386    fn from(error: open_ai::RequestError) -> Self {
 387        match error {
 388            open_ai::RequestError::HttpResponseError {
 389                provider,
 390                status_code,
 391                body,
 392                headers,
 393            } => {
 394                let retry_after = headers
 395                    .get(http::header::RETRY_AFTER)
 396                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
 397                    .map(Duration::from_secs);
 398
 399                Self::from_http_status(provider.into(), status_code, body, retry_after)
 400            }
 401            open_ai::RequestError::Other(e) => Self::Other(e),
 402        }
 403    }
 404}
 405
/// Converts transport- and protocol-level OpenRouter errors into typed
/// completion errors, attributing them all to the OpenRouter provider.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // Structured API errors get their own, code-based mapping below.
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}
 429
/// Maps a structured OpenRouter API error onto a typed completion error
/// by its error code.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // Billing problems are surfaced as authentication errors since
            // both require the user to fix their account/credentials.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
 471
/// Why a model stopped generating output.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation hit the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
 480
/// Token counts reported by a provider for a request.
// All counters default to 0 and are omitted from serialization when 0.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
 492
 493impl TokenUsage {
 494    pub fn total_tokens(&self) -> u64 {
 495        self.input_tokens
 496            + self.output_tokens
 497            + self.cache_read_input_tokens
 498            + self.cache_creation_input_tokens
 499    }
 500}
 501
 502impl Add<TokenUsage> for TokenUsage {
 503    type Output = Self;
 504
 505    fn add(self, other: Self) -> Self {
 506        Self {
 507            input_tokens: self.input_tokens + other.input_tokens,
 508            output_tokens: self.output_tokens + other.output_tokens,
 509            cache_creation_input_tokens: self.cache_creation_input_tokens
 510                + other.cache_creation_input_tokens,
 511            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
 512        }
 513    }
 514}
 515
 516impl Sub<TokenUsage> for TokenUsage {
 517    type Output = Self;
 518
 519    fn sub(self, other: Self) -> Self {
 520        Self {
 521            input_tokens: self.input_tokens - other.input_tokens,
 522            output_tokens: self.output_tokens - other.output_tokens,
 523            cache_creation_input_tokens: self.cache_creation_input_tokens
 524                - other.cache_creation_input_tokens,
 525            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
 526        }
 527    }
 528}
 529
/// Identifier for a single tool use, used to correlate a tool invocation
/// with its result.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
 532
 533impl fmt::Display for LanguageModelToolUseId {
 534    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 535        write!(f, "{}", self.0)
 536    }
 537}
 538
/// Allows constructing an id from any string-like value
/// (`&str`, `String`, `Arc<str>`, …).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
 547
/// A tool invocation requested by a model during completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Correlates this tool use with its eventual result.
    pub id: LanguageModelToolUseId,
    /// Name of the tool the model wants to invoke.
    pub name: Arc<str>,
    /// The tool input exactly as streamed from the model — may be partial
    /// (and thus invalid) JSON while streaming is in progress.
    pub raw_input: String,
    /// Parsed form of `raw_input`.
    pub input: serde_json::Value,
    /// Whether `input` reflects the complete arguments (streaming finished).
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
 559
/// A completion stream reduced to plain text chunks.
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when one was reported.
    pub message_id: Option<String>,
    /// The text chunks, or the error that terminated the stream.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
 566
 567impl Default for LanguageModelTextStream {
 568    fn default() -> Self {
 569        Self {
 570            message_id: None,
 571            stream: Box::pin(futures::stream::empty()),
 572            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 573        }
 574    }
 575}
 576
/// A thinking effort level a model can be asked to use.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    /// Display name shown in the UI.
    pub name: SharedString,
    /// Underlying value identifying this level.
    pub value: SharedString,
    /// Whether this is the model's default effort level.
    pub is_default: bool,
}
 583
/// A language model that can count tokens and stream completions.
pub trait LanguageModel: Send + Sync {
    /// Stable identifier for this model.
    fn id(&self) -> LanguageModelId;
    /// Human-readable model name.
    fn name(&self) -> LanguageModelName;
    /// Identifier of the provider serving this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Human-readable name of the provider serving this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Id of the provider actually fulfilling requests when this model is
    /// proxied; defaults to the direct provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name counterpart of [`LanguageModel::upstream_provider_id`].
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    /// Identifier reported in telemetry events.
    fn telemetry_id(&self) -> String;

    /// The configured API key for this model, if any.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Whether this model supports a fast mode.
    fn supports_fast_mode(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls;
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// Schema dialect the model expects for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Size of the model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, when the model imposes one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for `request`. This is the primitive the
    /// convenience wrappers below are built on.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Streams a completion reduced to plain text chunks. Non-text events
    /// are dropped, except `UsageUpdate`, which is recorded into the
    /// returned stream's `last_token_usage`.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Eagerly pull the first event to capture the message id if the
            // stream starts with `StartMessage`.
            // NOTE(review): any other first event — including an `Err` — is
            // silently discarded by the catch-all arm below; confirm this is
            // intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend the first text chunk (if any), then keep only text
            // chunks and errors from the remaining events.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    // Keep only the most recent usage report.
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Streams a completion and resolves with the first tool use whose
    /// input is complete, or an error if the stream ends without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching parameters for this model, if it supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcast helper for tests; panics unless the model is the fake one.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
 788
 789impl std::fmt::Debug for dyn LanguageModel {
 790    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 791        f.debug_struct("<dyn LanguageModel>")
 792            .field("id", &self.id())
 793            .field("name", &self.name())
 794            .field("provider_id", &self.provider_id())
 795            .field("provider_name", &self.provider_name())
 796            .field("upstream_provider_name", &self.upstream_provider_name())
 797            .field("upstream_provider_id", &self.upstream_provider_id())
 798            .field("upstream_provider_id", &self.upstream_provider_id())
 799            .field("supports_streaming_tools", &self.supports_streaming_tools())
 800            .finish()
 801    }
 802}
 803
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 814
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
 823
 824impl Default for IconOrSvg {
 825    fn default() -> Self {
 826        Self::Icon(IconName::ZedAssistant)
 827    }
 828}
 829
/// A provider of language models (e.g. Anthropic, OpenAI, Google).
///
/// Implementors enumerate their models, report authentication state, and
/// supply a configuration view for credentials.
pub trait LanguageModelProvider: 'static {
    /// Stable identifier for this provider (e.g. "anthropic").
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name (e.g. "Anthropic").
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the built-in assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if any is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A faster/cheaper default model, if any is available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// The subset of models recommended for general use; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether valid credentials are currently available.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate (e.g. load or validate stored credentials).
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
 852
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Another agent, identified by name.
    Other(SharedString),
}
 859
/// The UI surface in which a provider's terms-of-service view is rendered.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    // NOTE(review): the comments on the next two variants look swapped
    // relative to the variant names ("EmptyState" documented as having past
    // interactions, "FreshStart" as having none) — confirm against the UI code.
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
 869
/// Observable state for a language model provider.
pub trait LanguageModelProviderState: 'static {
    /// The entity whose changes signal provider-state updates.
    type ObservableEntity;

    /// The entity to observe, or `None` if the provider exposes no
    /// observable state.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Invokes `callback` whenever the observable entity changes.
    ///
    /// Returns `None` (and registers nothing) when there is no observable
    /// entity; otherwise returns the gpui subscription guard.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
 886
/// Unique identifier for a model within a provider (newtype over `SharedString`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable display name of a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Unique identifier for a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable display name of a provider (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
 898
/// Pricing information for a model.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens.
    /// (Field names say `per_1m`; the previous comment incorrectly said 1,000.)
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Flat cost per request.
    RequestCost { cost_per_request: f64 },
}
 909
 910impl LanguageModelCostInfo {
 911    pub fn to_shared_string(&self) -> SharedString {
 912        match self {
 913            LanguageModelCostInfo::RequestCost { cost_per_request } => {
 914                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
 915                SharedString::from(cost_str)
 916            }
 917            LanguageModelCostInfo::TokenCost {
 918                input_token_cost_per_1m,
 919                output_token_cost_per_1m,
 920            } => {
 921                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
 922                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
 923                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
 924            }
 925        }
 926    }
 927
 928    fn cost_value_to_string(cost: &f64) -> SharedString {
 929        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
 930            SharedString::from(format!("{:.0}", cost))
 931        } else {
 932            SharedString::from(format!("{:.2}", cost))
 933        }
 934    }
 935}
 936
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; usable in `const` contexts
    /// (see the provider-id constants at the top of this file).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
 942
 943impl LanguageModelProviderName {
 944    pub const fn new(id: &'static str) -> Self {
 945        Self(SharedString::new_static(id))
 946    }
 947}
 948
 949impl fmt::Display for LanguageModelProviderId {
 950    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 951        write!(f, "{}", self.0)
 952    }
 953}
 954
 955impl fmt::Display for LanguageModelProviderName {
 956    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 957        write!(f, "{}", self.0)
 958    }
 959}
 960
 961impl From<String> for LanguageModelId {
 962    fn from(value: String) -> Self {
 963        Self(SharedString::from(value))
 964    }
 965}
 966
 967impl From<String> for LanguageModelName {
 968    fn from(value: String) -> Self {
 969        Self(SharedString::from(value))
 970    }
 971}
 972
 973impl From<String> for LanguageModelProviderId {
 974    fn from(value: String) -> Self {
 975        Self(SharedString::from(value))
 976    }
 977}
 978
 979impl From<String> for LanguageModelProviderName {
 980    fn from(value: String) -> Self {
 981        Self(SharedString::from(value))
 982    }
 983}
 984
 985impl From<Arc<str>> for LanguageModelProviderId {
 986    fn from(value: Arc<str>) -> Self {
 987        Self(SharedString::from(value))
 988    }
 989}
 990
 991impl From<Arc<str>> for LanguageModelProviderName {
 992    fn from(value: Arc<str>) -> Self {
 993        Self(SharedString::from(value))
 994    }
 995}
 996
#[cfg(test)]
mod tests {
    use super::*;

    // Cloud `upstream_http_error` payloads carry the real status in the JSON
    // body: 503 must map to ServerOverloaded, 500 to ApiInternalServerError,
    // both preserving the provider id.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // The alternative `upstream_http_<status>` code format (status embedded in
    // the code string rather than in a JSON body) must map 503 the same way.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // Connection-timeout messages must still be classified by the embedded
    // upstream status, and the message must survive into the 500 variant.
    // NOTE(review): the first case here is byte-identical to the first case of
    // test_from_cloud_failure_with_upstream_http_error; consider consolidating.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // Serialization includes `thought_signature` when present.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Deserialization tolerates a missing `thought_signature` field (older
    // payloads), defaulting it to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Round trip: a present signature survives serialize -> deserialize.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Round trip: an absent signature stays `None` through serialize ->
    // deserialize.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}