language_model.rs

mod api_key;
mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::api_key::{ApiKey, ApiKeyState};
pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use zed_env_vars::{EnvVar, env_var};

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
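/// # Examples
///
/// A sketch of how a consumer might drain the event stream returned by
/// [`LanguageModel::stream_completion`]; `events` here is hypothetical and
/// `futures::StreamExt` is assumed to be in scope:
///
/// ```ignore
/// while let Some(event) = events.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
///         LanguageModelCompletionEvent::Stop(reason) => println!("\nstopped: {reason:?}"),
///         _ => {}
///     }
/// }
/// ```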
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Queued {
        position: usize,
    },
    Started,
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    ReasoningDetails(serde_json::Value),
    UsageUpdate(TokenUsage),
}

impl LanguageModelCompletionEvent {
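    /// Converts a [`CompletionRequestStatus`] reported by Zed's cloud service
    /// into the corresponding completion event, or into a
    /// [`LanguageModelCompletionError`] when the request failed.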
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Self, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(LanguageModelCompletionEvent::Queued { position })
            }
            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
            CompletionRequestStatus::UsageUpdated { amount, limit } => {
                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
            }
            CompletionRequestStatus::ToolUseLimitReached => {
                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
            }
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
 184    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
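    /// Extracts the upstream status code and message from an
    /// `upstream_http_error` payload of the form
    /// `{"upstream_status": 503, "message": "..."}`, falling back to the raw
    /// message when the `message` field is absent.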
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

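    /// Maps a failure reported by Zed's cloud service to a structured error.
    ///
    /// A message matching Anthropic's prompt-too-long format is mapped to
    /// [`Self::PromptTooLarge`] regardless of the code. Otherwise three code
    /// formats are recognized: `upstream_http_error` carrying a JSON payload
    /// with the upstream status, `upstream_http_<status>` for upstream
    /// provider errors, and `http_<status>` for errors from Zed's own
    /// servers. Anything else becomes [`Self::Other`].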
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

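    /// Maps an HTTP status code from a provider's API to a structured error.
    /// The non-standard 529 status, which Anthropic uses to signal an
    /// overloaded server, is treated the same as 503.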
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

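/// Cumulative token counts reported for a request. Cache creation and cache
/// read tokens are tracked separately from ordinary input tokens; the sum of
/// all four counters is available via [`TokenUsage::total_tokens`].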
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

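// Componentwise difference. Note that `u64` subtraction panics on underflow
// in debug builds, so callers must ensure `other` never exceeds `self` in any
// field.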
 522impl Sub<TokenUsage> for TokenUsage {
 523    type Output = Self;
 524
 525    fn sub(self, other: Self) -> Self {
 526        Self {
 527            input_tokens: self.input_tokens - other.input_tokens,
 528            output_tokens: self.output_tokens - other.output_tokens,
 529            cache_creation_input_tokens: self.cache_creation_input_tokens
 530                - other.cache_creation_input_tokens,
 531            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
 532        }
 533    }
 534}
 535
 536#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 537pub struct LanguageModelToolUseId(Arc<str>);
 538
 539impl fmt::Display for LanguageModelToolUseId {
 540    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 541        write!(f, "{}", self.0)
 542    }
 543}
 544
 545impl<T> From<T> for LanguageModelToolUseId
 546where
 547    T: Into<Arc<str>>,
 548{
 549    fn from(value: T) -> Self {
 550        Self(value.into())
 551    }
 552}
 553
 554#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 555pub struct LanguageModelToolUse {
 556    pub id: LanguageModelToolUseId,
 557    pub name: Arc<str>,
 558    pub raw_input: String,
 559    pub input: serde_json::Value,
 560    pub is_input_complete: bool,
 561    /// Thought signature the model sent us. Some models require that this
 562    /// signature be preserved and sent back in conversation history for validation.
 563    pub thought_signature: Option<String>,
 564}
 565
 566pub struct LanguageModelTextStream {
 567    pub message_id: Option<String>,
 568    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
 569    // Has complete token usage after the stream has finished
 570    pub last_token_usage: Arc<Mutex<TokenUsage>>,
 571}
 572
 573impl Default for LanguageModelTextStream {
 574    fn default() -> Self {
 575        Self {
 576            message_id: None,
 577            stream: Box::pin(futures::stream::empty()),
 578            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 579        }
 580    }
 581}
 582
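/// A handle to a language model exposed by a [`LanguageModelProvider`].
///
/// A sketch of typical streaming use, assuming a `model: Arc<dyn LanguageModel>`,
/// a prepared `request`, and an async context `cx` (all hypothetical), with
/// `futures::StreamExt` in scope:
///
/// ```ignore
/// let mut text_stream = model.stream_completion_text(request, &cx).await?;
/// while let Some(chunk) = text_stream.stream.next().await {
///     print!("{}", chunk?);
/// }
/// let usage = *text_stream.last_token_usage.lock();
/// ```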
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode" ([`CompletionMode::Max`]).
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

impl std::fmt::Debug for dyn LanguageModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("<dyn LanguageModel>")
            .field("id", &self.id())
            .field("name", &self.name())
            .field("provider_id", &self.provider_id())
            .field("provider_name", &self.provider_name())
            .field("upstream_provider_name", &self.upstream_provider_name())
            .field("upstream_provider_id", &self.upstream_provider_id())
 783            .field("upstream_provider_id", &self.upstream_provider_id())
 784            .field("supports_streaming_tools", &self.supports_streaming_tools())
 785            .finish()
 786    }
 787}
 788
 789/// An error that occurred when trying to authenticate the language model provider.
 790#[derive(Debug, Error)]
 791pub enum AuthenticateError {
 792    #[error("connection refused")]
 793    ConnectionRefused,
 794    #[error("credentials not found")]
 795    CredentialsNotFound,
 796    #[error(transparent)]
 797    Other(#[from] anyhow::Error),
 798}
 799
 800/// Either a built-in icon name or a path to an external SVG.
 801#[derive(Debug, Clone, PartialEq, Eq)]
 802pub enum IconOrSvg {
 803    /// A built-in icon from Zed's icon set.
 804    Icon(IconName),
 805    /// Path to a custom SVG icon file.
 806    Svg(SharedString),
 807}
 808
 809impl Default for IconOrSvg {
 810    fn default() -> Self {
 811        Self::Icon(IconName::ZedAssistant)
 812    }
 813}
 814
 815pub trait LanguageModelProvider: 'static {
 816    fn id(&self) -> LanguageModelProviderId;
 817    fn name(&self) -> LanguageModelProviderName;
 818    fn icon(&self) -> IconOrSvg {
 819        IconOrSvg::default()
 820    }
 821    /// Returns the path to an external SVG icon for this provider, if any.
 822    /// When present, this takes precedence over `icon()`.
 823    fn icon_path(&self) -> Option<SharedString> {
 824        None
 825    }
 826    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
 827    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
 828    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
 829    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 830        Vec::new()
 831    }
 832    fn is_authenticated(&self, cx: &App) -> bool;
 833    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
 834    fn configuration_view(
 835        &self,
 836        target_agent: ConfigurationViewTargetAgent,
 837        window: &mut Window,
 838        cx: &mut App,
 839    ) -> AnyView;
 840    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
 841}
 842
 843#[derive(Default, Clone, PartialEq, Eq)]
 844pub enum ConfigurationViewTargetAgent {
 845    #[default]
 846    ZedAgent,
 847    EditPrediction,
 848    Other(SharedString),
 849}
 850
 851#[derive(PartialEq, Eq)]
 852pub enum LanguageModelProviderTosView {
 853    /// When there are some past interactions in the Agent Panel.
 854    ThreadEmptyState,
 855    /// When there are no past interactions in the Agent Panel.
 856    ThreadFreshStart,
 857    TextThreadPopup,
 858    Configuration,
 859}
 860
 861pub trait LanguageModelProviderState: 'static {
 862    type ObservableEntity;
 863
 864    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
 865
 866    fn subscribe<T: 'static>(
 867        &self,
 868        cx: &mut gpui::Context<T>,
 869        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
 870    ) -> Option<gpui::Subscription> {
 871        let entity = self.observable_entity()?;
 872        Some(cx.observe(&entity, move |this, _, cx| {
 873            callback(this, cx);
 874        }))
 875    }
 876}
 877
 878#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
 879pub struct LanguageModelId(pub SharedString);
 880
 881#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 882pub struct LanguageModelName(pub SharedString);
 883
 884#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 885pub struct LanguageModelProviderId(pub SharedString);
 886
 887#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 888pub struct LanguageModelProviderName(pub SharedString);
 889
 890impl LanguageModelProviderId {
 891    pub const fn new(id: &'static str) -> Self {
 892        Self(SharedString::new_static(id))
 893    }
 894}
 895
 896impl LanguageModelProviderName {
    pub const fn new(name: &'static str) -> Self {
        Self(SharedString::new_static(name))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

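    // Covers the non-standard 529 status, which `from_http_status` maps to an
    // overloaded server just like 503.
    #[test]
    fn test_from_http_status_529_maps_to_server_overloaded() {
        let error = LanguageModelCompletionError::from_http_status(
            LanguageModelProviderName::new("anthropic"),
            StatusCode::from_u16(529).unwrap(),
            "Overloaded".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for status 529, got: {error:?}"),
        }
    }
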
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
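
    // Exercises `TokenUsage` arithmetic: `total_tokens` sums all four
    // counters, and `Add`/`Sub` operate componentwise.
    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 10,
            output_tokens: 20,
            cache_creation_input_tokens: 30,
            cache_read_input_tokens: 40,
        };
        let b = TokenUsage {
            input_tokens: 1,
            output_tokens: 2,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 4,
        };

        assert_eq!(a.total_tokens(), 100);
        assert_eq!((a + b).total_tokens(), 110);

        let diff = a - b;
        assert_eq!(diff.input_tokens, 9);
        assert_eq!(diff.output_tokens, 18);
        assert_eq!(diff.cache_creation_input_tokens, 27);
        assert_eq!(diff.cache_read_input_tokens, 36);
    }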
}