language_model.rs

   1mod api_key;
   2mod model;
   3mod rate_limiter;
   4mod registry;
   5mod request;
   6mod role;
   7mod telemetry;
   8pub mod tool_schema;
   9
  10#[cfg(any(test, feature = "test-support"))]
  11pub mod fake_provider;
  12
  13use anthropic::{AnthropicError, parse_prompt_too_long};
  14use anyhow::{Result, anyhow};
  15use client::Client;
  16use cloud_llm_client::CompletionRequestStatus;
  17use futures::FutureExt;
  18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
  19use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
  20use http_client::{StatusCode, http};
  21use icons::IconName;
  22use open_router::OpenRouterError;
  23use parking_lot::Mutex;
  24use serde::{Deserialize, Serialize};
  25pub use settings::LanguageModelCacheConfiguration;
  26use std::ops::{Add, Sub};
  27use std::str::FromStr;
  28use std::sync::Arc;
  29use std::time::Duration;
  30use std::{fmt, io};
  31use thiserror::Error;
  32use util::serde::is_default;
  33
  34pub use crate::api_key::{ApiKey, ApiKeyState};
  35pub use crate::model::*;
  36pub use crate::rate_limiter::*;
  37pub use crate::registry::*;
  38pub use crate::request::*;
  39pub use crate::role::*;
  40pub use crate::telemetry::*;
  41pub use crate::tool_schema::LanguageModelToolSchemaFormat;
  42pub use zed_env_vars::{EnvVar, env_var};
  43
// Well-known provider identifiers and display names. The ID is the stable
// machine-readable key (used in settings/telemetry); the NAME is what the UI shows.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

// Zed's own cloud provider.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
  63
/// Initializes the language-model subsystem: registers settings, then hooks up
/// the LLM token refresh listener on the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
  68
/// Registers the language-model registry's settings. Split out from [`init`]
/// so callers without a `Client` (e.g. tests) can still set up settings.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
  72
  73/// A completion event from a language model.
  74#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
  75pub enum LanguageModelCompletionEvent {
  76    Queued {
  77        position: usize,
  78    },
  79    Started,
  80    Stop(StopReason),
  81    Text(String),
  82    Thinking {
  83        text: String,
  84        signature: Option<String>,
  85    },
  86    RedactedThinking {
  87        data: String,
  88    },
  89    ToolUse(LanguageModelToolUse),
  90    ToolUseJsonParseError {
  91        id: LanguageModelToolUseId,
  92        tool_name: Arc<str>,
  93        raw_input: Arc<str>,
  94        json_parse_error: String,
  95    },
  96    StartMessage {
  97        message_id: String,
  98    },
  99    ReasoningDetails(serde_json::Value),
 100    UsageUpdate(TokenUsage),
 101}
 102
 103impl LanguageModelCompletionEvent {
 104    pub fn from_completion_request_status(
 105        status: CompletionRequestStatus,
 106        upstream_provider: LanguageModelProviderName,
 107    ) -> Result<Option<Self>, LanguageModelCompletionError> {
 108        match status {
 109            CompletionRequestStatus::Queued { position } => {
 110                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
 111            }
 112            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
 113            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
 114            CompletionRequestStatus::Failed {
 115                code,
 116                message,
 117                request_id: _,
 118                retry_after,
 119            } => Err(LanguageModelCompletionError::from_cloud_failure(
 120                upstream_provider,
 121                code,
 122                message,
 123                retry_after.map(Duration::from_secs_f64),
 124            )),
 125        }
 126    }
 127}
 128
/// Errors that can occur while requesting or streaming a completion.
///
/// Variants carry the provider they originated from so UI/telemetry can
/// attribute the failure; `retry_after` fields carry server-suggested backoff.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the
    /// parsed token count when the provider reported one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded verbatim from an upstream provider behind a proxy.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not mapped to a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 219
impl LanguageModelCompletionError {
    /// Extracts `(upstream_status, message)` from a JSON-encoded
    /// `upstream_http_error` payload. Returns `None` if `message` is not JSON
    /// or carries no valid `upstream_status` code.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        // Fall back to the whole raw message if there is no "message" field.
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Maps a failed cloud completion request (`code` + `message`) to a typed
    /// error. Branch order matters: the prompt-too-long check runs first, then
    /// `upstream_http_error` JSON payloads, then `upstream_http_<status>` and
    /// `http_<status>` code prefixes; anything else falls through to `Other`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Unparseable upstream payload: keep the raw code/message.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // A bare "http_<status>" code means Zed's own cloud failed,
            // not the upstream provider.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status code from `provider` to the matching error variant.
    /// Unmatched statuses become the generic `HttpResponseError`.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is Anthropic's non-standard "overloaded" status.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
 311
/// Maps Anthropic client errors onto the shared completion-error type,
/// attributing them all to the Anthropic provider.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors are further dispatched by their error code below.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
 343
/// Maps Anthropic API error codes to typed variants; errors without a
/// recognized code fall back to `Other`.
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    // Try to recover the reported token count from the message text.
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}
 383
/// Maps OpenAI-client request errors onto the shared completion-error type,
/// extracting a retry delay from the `Retry-After` header when present.
impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                // NOTE(review): only the integer-seconds form of Retry-After is
                // parsed; the HTTP-date form is silently ignored — confirm
                // whether providers here ever send a date.
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}
 404
/// Maps OpenRouter client errors onto the shared completion-error type,
/// attributing them all to the "OpenRouter" provider.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors are further dispatched by their error code below.
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}
 428
/// Maps OpenRouter API error codes to typed variants.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // Billing problems are surfaced as authentication errors so the
            // user is directed to their account/credentials.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
 470
/// Why the model stopped generating, serialized in `snake_case`
/// (e.g. `end_turn`, `max_tokens`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}
 479
/// Token counts for a single request; zero-valued fields are omitted when
/// serialized and default to zero when deserialized.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
 491
 492impl TokenUsage {
 493    pub fn total_tokens(&self) -> u64 {
 494        self.input_tokens
 495            + self.output_tokens
 496            + self.cache_read_input_tokens
 497            + self.cache_creation_input_tokens
 498    }
 499}
 500
 501impl Add<TokenUsage> for TokenUsage {
 502    type Output = Self;
 503
 504    fn add(self, other: Self) -> Self {
 505        Self {
 506            input_tokens: self.input_tokens + other.input_tokens,
 507            output_tokens: self.output_tokens + other.output_tokens,
 508            cache_creation_input_tokens: self.cache_creation_input_tokens
 509                + other.cache_creation_input_tokens,
 510            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
 511        }
 512    }
 513}
 514
 515impl Sub<TokenUsage> for TokenUsage {
 516    type Output = Self;
 517
 518    fn sub(self, other: Self) -> Self {
 519        Self {
 520            input_tokens: self.input_tokens - other.input_tokens,
 521            output_tokens: self.output_tokens - other.output_tokens,
 522            cache_creation_input_tokens: self.cache_creation_input_tokens
 523                - other.cache_creation_input_tokens,
 524            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
 525        }
 526    }
 527}
 528
/// Identifier for a tool use, as assigned by the model/provider.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
 531
 532impl fmt::Display for LanguageModelToolUseId {
 533    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 534        write!(f, "{}", self.0)
 535    }
 536}
 537
/// Allows constructing an id from anything convertible to `Arc<str>`
/// (`&str`, `String`, `Arc<str>`, …).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
 546
/// A tool call emitted by the model, possibly still streaming in.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    // Name of the tool being invoked.
    pub name: Arc<str>,
    // The raw (possibly partial) JSON input text as streamed so far.
    pub raw_input: String,
    // Parsed form of `raw_input`.
    pub input: serde_json::Value,
    // True once the input JSON has finished streaming and is complete.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
 558
/// A text-only view over a completion stream, produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    // Provider-assigned message ID, if a StartMessage event was seen first.
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
 565
 566impl Default for LanguageModelTextStream {
 567    fn default() -> Self {
 568        Self {
 569            message_id: None,
 570            stream: Box::pin(futures::stream::empty()),
 571            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 572        }
 573    }
 574}
 575
/// A named reasoning-effort option a model supports (e.g. for thinking modes).
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    // Human-readable label shown in the UI.
    pub name: SharedString,
    // Value sent to the provider's API.
    pub value: SharedString,
    // Whether this is the level selected by default.
    pub is_default: bool,
}
 582
/// Interface implemented by every language model, regardless of provider.
///
/// `stream_completion` is the one required streaming entry point; the
/// `stream_completion_text` and `stream_completion_tool` helpers are default
/// adapters built on top of it.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// ID of the provider actually serving requests (may differ from
    /// `provider_id` when proxied, e.g. through Zed cloud).
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Display name counterpart of [`Self::upstream_provider_id`].
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    /// Stable identifier used in telemetry events.
    fn telemetry_id(&self) -> String;

    /// The configured API key for this model's provider, if any.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls;
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// Schema dialect expected for tool input definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum context-window size in tokens.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, if the provider imposes one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Starts a completion and returns the raw event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Adapter over [`Self::stream_completion`] that yields only text chunks,
    /// capturing the message ID (from a leading `StartMessage` event) and
    /// recording token usage into `last_token_usage` as `UsageUpdate` events
    /// arrive.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message ID, or be the
            // first text chunk (which must be re-emitted below).
            // NOTE(review): if the first event is an Err it is silently
            // discarded here, while later Errs are forwarded — confirm this
            // is intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-emit the first text chunk (if any), then keep only Text
            // events and errors; UsageUpdate events are folded into
            // `last_token_usage` instead of being emitted.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Adapter over [`Self::stream_completion`] that resolves to the first
    /// *complete* tool use in the stream, or an error if the stream ends
    /// without one.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-cache configuration for this model, if it supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Test-only downcast to the fake model; panics for real implementations.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
 783
 784impl std::fmt::Debug for dyn LanguageModel {
 785    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 786        f.debug_struct("<dyn LanguageModel>")
 787            .field("id", &self.id())
 788            .field("name", &self.name())
 789            .field("provider_id", &self.provider_id())
 790            .field("provider_name", &self.provider_name())
 791            .field("upstream_provider_name", &self.upstream_provider_name())
 792            .field("upstream_provider_id", &self.upstream_provider_id())
 793            .field("upstream_provider_id", &self.upstream_provider_id())
 794            .field("supports_streaming_tools", &self.supports_streaming_tools())
 795            .finish()
 796    }
 797}
 798
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 809
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
 818
 819impl Default for IconOrSvg {
 820    fn default() -> Self {
 821        Self::Icon(IconName::ZedAssistant)
 822    }
 823}
 824
/// A source of language models (e.g. Anthropic, OpenAI, Google AI) that can
/// enumerate its models and manage its own authentication state.
pub trait LanguageModelProvider: 'static {
    /// Stable machine-readable identifier for this provider (e.g. `"anthropic"`).
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name (e.g. `"Anthropic"`).
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the built-in Zed Assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A cheaper/faster default model, if the provider distinguishes one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models recommended for typical use; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether credentials are present so requests can be made.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authentication (e.g. loading or validating stored credentials).
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration view for the given target agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
 847
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other agent, identified by its display name.
    Other(SharedString),
}
 854
/// UI surface in which a provider's terms-of-service prompt is displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Shown as a popup in a text thread.
    TextThreadPopup,
    /// Shown within the provider configuration UI.
    Configuration,
}
 864
/// Implemented by providers whose internal state can be observed for changes.
pub trait LanguageModelProviderState: 'static {
    /// The entity type whose notifications signal provider state changes.
    type ObservableEntity;

    /// Returns the entity to observe, or `None` if this provider has none.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Invokes `callback` whenever the observable entity notifies.
    /// Returns `None` (no subscription) when there is nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
 881
/// Machine-readable identifier for a specific model within a provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
 884
/// Human-readable display name of a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
 887
/// Machine-readable identifier for a model provider (e.g. `"anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
 890
/// Human-readable display name of a model provider (e.g. `"Anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
 893
/// How a model's usage is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (field names are `per_1m`;
    /// the previous comment incorrectly said "per 1,000").
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Flat cost per request, regardless of token counts.
    RequestCost { cost_per_request: f64 },
}
 904
 905impl LanguageModelCostInfo {
 906    pub fn to_shared_string(&self) -> SharedString {
 907        match self {
 908            LanguageModelCostInfo::RequestCost { cost_per_request } => {
 909                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
 910                SharedString::from(cost_str)
 911            }
 912            LanguageModelCostInfo::TokenCost {
 913                input_token_cost_per_1m,
 914                output_token_cost_per_1m,
 915            } => {
 916                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
 917                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
 918                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
 919            }
 920        }
 921    }
 922
 923    fn cost_value_to_string(cost: &f64) -> SharedString {
 924        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
 925            SharedString::from(format!("{:.0}", cost))
 926        } else {
 927            SharedString::from(format!("{:.2}", cost))
 928        }
 929    }
 930}
 931
impl LanguageModelProviderId {
    /// Creates a provider id from a static string (usable in `const` contexts).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
 937
 938impl LanguageModelProviderName {
 939    pub const fn new(id: &'static str) -> Self {
 940        Self(SharedString::new_static(id))
 941    }
 942}
 943
 944impl fmt::Display for LanguageModelProviderId {
 945    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 946        write!(f, "{}", self.0)
 947    }
 948}
 949
 950impl fmt::Display for LanguageModelProviderName {
 951    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 952        write!(f, "{}", self.0)
 953    }
 954}
 955
 956impl From<String> for LanguageModelId {
 957    fn from(value: String) -> Self {
 958        Self(SharedString::from(value))
 959    }
 960}
 961
 962impl From<String> for LanguageModelName {
 963    fn from(value: String) -> Self {
 964        Self(SharedString::from(value))
 965    }
 966}
 967
 968impl From<String> for LanguageModelProviderId {
 969    fn from(value: String) -> Self {
 970        Self(SharedString::from(value))
 971    }
 972}
 973
 974impl From<String> for LanguageModelProviderName {
 975    fn from(value: String) -> Self {
 976        Self(SharedString::from(value))
 977    }
 978}
 979
 980impl From<Arc<str>> for LanguageModelProviderId {
 981    fn from(value: Arc<str>) -> Self {
 982        Self(SharedString::from(value))
 983    }
 984}
 985
 986impl From<Arc<str>> for LanguageModelProviderName {
 987    fn from(value: Arc<str>) -> Self {
 988        Self(SharedString::from(value))
 989    }
 990}
 991
#[cfg(test)]
mod tests {
    use super::*;

    // A cloud `upstream_http_error` payload carrying upstream_status 503
    // maps to `ServerOverloaded`; 500 maps to `ApiInternalServerError`,
    // preserving the provider id and the embedded message.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // The older `upstream_http_<status>` code format (status embedded in the
    // code rather than a JSON payload) is also mapped by status.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // NOTE(review): the first half of this test is byte-identical to
    // test_from_cloud_failure_with_upstream_http_error above (same payload
    // and assertions); consider consolidating to avoid redundant coverage.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        // Same timeout message but upstream_status 500: classification must
        // follow the status code, not the message text.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // `thought_signature`, when present, is serialized alongside the other
    // tool-use fields.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Backward compatibility: payloads written before `thought_signature`
    // existed must still deserialize, defaulting the field to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // serialize → deserialize round trip preserves a present signature.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // serialize → deserialize round trip preserves an absent signature.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}