language_model.rs

   1mod api_key;
   2mod model;
   3mod rate_limiter;
   4mod registry;
   5mod request;
   6mod role;
   7mod telemetry;
   8pub mod tool_schema;
   9
  10#[cfg(any(test, feature = "test-support"))]
  11pub mod fake_provider;
  12
  13use anthropic::{AnthropicError, parse_prompt_too_long};
  14use anyhow::{Result, anyhow};
  15use client::Client;
  16use cloud_llm_client::CompletionRequestStatus;
  17use futures::FutureExt;
  18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
  19use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
  20use http_client::{StatusCode, http};
  21use icons::IconName;
  22use open_router::OpenRouterError;
  23use parking_lot::Mutex;
  24use serde::{Deserialize, Serialize};
  25pub use settings::LanguageModelCacheConfiguration;
  26use std::ops::{Add, Sub};
  27use std::str::FromStr;
  28use std::sync::Arc;
  29use std::time::Duration;
  30use std::{fmt, io};
  31use thiserror::Error;
  32use util::serde::is_default;
  33
  34pub use crate::api_key::{ApiKey, ApiKeyState};
  35pub use crate::model::*;
  36pub use crate::rate_limiter::*;
  37pub use crate::registry::*;
  38pub use crate::request::*;
  39pub use crate::role::*;
  40pub use crate::telemetry::*;
  41pub use crate::tool_schema::LanguageModelToolSchemaFormat;
  42pub use zed_env_vars::{EnvVar, env_var};
  43
// Stable identifiers and user-facing display names for the built-in
// providers. The id is the machine-readable key; the name is what gets
// shown in the UI and interpolated into error messages.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
  63
/// Initializes this crate: registers settings state and the listener that
/// refreshes LLM tokens over the given client connection.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
  68
/// Initializes the model registry; usable without a client connection
/// (see [`init`] for the full initialization path).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
  72
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The provider has started processing the request.
    Started,
    /// Generation stopped, with the reason why.
    Stop(StopReason),
    /// A chunk of assistant output text.
    Text(String),
    /// A chunk of the model's thinking output, with an optional signature
    /// (see [`LanguageModelToolUse::thought_signature`] for why signatures matter).
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted thinking content from the provider.
    RedactedThinking {
        data: String,
    },
    /// A (possibly still-streaming) tool call from the model.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Marks the start of a new assistant message with its provider-assigned id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counts for the request.
    UsageUpdate(TokenUsage),
}
 102
impl LanguageModelCompletionEvent {
    /// Translates a cloud [`CompletionRequestStatus`] into a completion event.
    ///
    /// Returns `Ok(Some(..))` for `Queued`/`Started`, `Ok(None)` for statuses
    /// with no event equivalent (`Unknown`, `StreamEnded`), and `Err(..)` for
    /// `Failed` (via [`LanguageModelCompletionError::from_cloud_failure`]) as
    /// well as for statuses this function does not expect to receive
    /// (`UsageUpdated`, `ToolUseLimitReached`).
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Option<Self>, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
            }
            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
            CompletionRequestStatus::UsageUpdated { .. }
            | CompletionRequestStatus::ToolUseLimitReached => Err(
                LanguageModelCompletionError::Other(anyhow!("Unexpected status: {status:?}")),
            ),
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                // Cloud reports the retry delay as fractional seconds.
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}
 132
/// Errors that can occur while running a completion, spanning request
/// construction, transport, HTTP status mapping, and provider-reported
/// failures. The `#[error]` strings are user-visible.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the
    /// parsed token count when the provider's message included one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded from an upstream provider; the message is shown
    /// verbatim, with the upstream HTTP status and optional retry delay kept
    /// alongside it.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses that have no dedicated variant above.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 223
impl LanguageModelCompletionError {
    /// Extracts `(status, message)` from an `upstream_http_error` payload of
    /// the shape `{"upstream_status": <u16>, "message": "..."}`. Returns
    /// `None` if the payload isn't JSON or lacks a valid status code; when
    /// `message` is absent the whole payload serves as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Converts a failed cloud completion (error `code` plus `message`) into a
    /// structured error. Checked in order:
    /// 1. "prompt too long" messages (regardless of code — see workaround below);
    /// 2. `upstream_http_error` codes carrying a JSON payload;
    /// 3. `upstream_http_<status>` codes, attributed to the upstream provider;
    /// 4. `http_<status>` codes, attributed to Zed's cloud itself;
    /// 5. everything else falls back to an opaque `Other` error.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Payload didn't match the expected JSON shape; keep the raw text.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from a provider's API to a structured error.
    /// Non-standard status 529 is also treated as "overloaded" (used by some
    /// providers for that purpose).
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
 315
/// Maps Anthropic client errors onto the provider-agnostic error type,
/// attributing every variant to the Anthropic provider.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level (body-reported) errors have their own conversion below.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
 347
/// Maps Anthropic API-reported error codes onto the provider-agnostic error
/// type. Errors whose code isn't recognized fall back to `Other`.
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                // Token count, when present, is embedded in the message text.
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                // No retry delay is available from the body-level error.
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}
 387
/// Maps OpenAI request errors onto the provider-agnostic error type,
/// extracting a retry delay from the `Retry-After` header when present.
impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                // NOTE(review): only the delta-seconds form of Retry-After is
                // parsed here; the HTTP-date form would yield `None` — confirm
                // that's acceptable for the providers involved.
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}
 408
/// Maps OpenRouter client errors onto the provider-agnostic error type,
/// attributing every variant to the "OpenRouter" provider name.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level (body-reported) errors have their own conversion below.
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}
 432
/// Maps OpenRouter API-reported error codes onto the provider-agnostic error
/// type. There is no dedicated "payment required" variant, so that case is
/// folded into `AuthenticationError` with a prefixed message.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
 474
/// Why the model stopped generating.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn naturally.
    EndTurn,
    /// Output was cut off by the token limit.
    MaxTokens,
    /// The model stopped in order to call a tool.
    ToolUse,
    /// The model declined to respond.
    Refusal,
}
 483
/// Token counts reported by a provider for a request. Fields default to zero
/// and zero values are omitted when serialized.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the request input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens generated as output.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Tokens written to the provider's prompt cache.
    /// NOTE(review): `total_tokens` sums all four fields, which assumes the
    /// cache counts are disjoint from `input_tokens` — confirm per provider.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Tokens read from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
 495
 496impl TokenUsage {
 497    pub fn total_tokens(&self) -> u64 {
 498        self.input_tokens
 499            + self.output_tokens
 500            + self.cache_read_input_tokens
 501            + self.cache_creation_input_tokens
 502    }
 503}
 504
 505impl Add<TokenUsage> for TokenUsage {
 506    type Output = Self;
 507
 508    fn add(self, other: Self) -> Self {
 509        Self {
 510            input_tokens: self.input_tokens + other.input_tokens,
 511            output_tokens: self.output_tokens + other.output_tokens,
 512            cache_creation_input_tokens: self.cache_creation_input_tokens
 513                + other.cache_creation_input_tokens,
 514            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
 515        }
 516    }
 517}
 518
 519impl Sub<TokenUsage> for TokenUsage {
 520    type Output = Self;
 521
 522    fn sub(self, other: Self) -> Self {
 523        Self {
 524            input_tokens: self.input_tokens - other.input_tokens,
 525            output_tokens: self.output_tokens - other.output_tokens,
 526            cache_creation_input_tokens: self.cache_creation_input_tokens
 527                - other.cache_creation_input_tokens,
 528            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
 529        }
 530    }
 531}
 532
/// Identifier for a tool use, wrapping a cheaply-clonable shared string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
 535
 536impl fmt::Display for LanguageModelToolUseId {
 537    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 538        write!(f, "{}", self.0)
 539    }
 540}
 541
 542impl<T> From<T> for LanguageModelToolUseId
 543where
 544    T: Into<Arc<str>>,
 545{
 546    fn from(value: T) -> Self {
 547        Self(value.into())
 548    }
 549}
 550
/// A tool call emitted by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The raw input string as received from the model (may be partial while
    /// `is_input_complete` is false).
    pub raw_input: String,
    /// The input parsed as JSON.
    pub input: serde_json::Value,
    /// Whether the input has finished streaming.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
 562
/// A text-only view of a completion stream (see
/// [`LanguageModel::stream_completion_text`]).
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when one was reported.
    pub message_id: Option<String>,
    /// Stream of text chunks; non-text events are filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
 569
 570impl Default for LanguageModelTextStream {
 571    fn default() -> Self {
 572        Self {
 573            message_id: None,
 574            stream: Box::pin(futures::stream::empty()),
 575            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 576        }
 577    }
 578}
 579
/// An "effort" setting for models that support thinking
/// (see [`LanguageModel::supported_effort_levels`]).
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    /// Human-readable label for this effort level.
    pub name: SharedString,
    /// Underlying value identifying the effort level.
    pub value: SharedString,
    /// Whether this is the default level for the model.
    pub is_default: bool,
}
 586
 587pub trait LanguageModel: Send + Sync {
 588    fn id(&self) -> LanguageModelId;
 589    fn name(&self) -> LanguageModelName;
 590    fn provider_id(&self) -> LanguageModelProviderId;
 591    fn provider_name(&self) -> LanguageModelProviderName;
 592    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 593        self.provider_id()
 594    }
 595    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 596        self.provider_name()
 597    }
 598
 599    /// Returns whether this model is the "latest", so we can highlight it in the UI.
 600    fn is_latest(&self) -> bool {
 601        false
 602    }
 603
 604    fn telemetry_id(&self) -> String;
 605
 606    fn api_key(&self, _cx: &App) -> Option<String> {
 607        None
 608    }
 609
 610    /// Whether this model supports thinking.
 611    fn supports_thinking(&self) -> bool {
 612        false
 613    }
 614
 615    /// Returns the list of supported effort levels that can be used when thinking.
 616    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
 617        Vec::new()
 618    }
 619
 620    /// Returns the default effort level to use when thinking.
 621    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
 622        self.supported_effort_levels()
 623            .into_iter()
 624            .find(|effort_level| effort_level.is_default)
 625    }
 626
 627    /// Whether this model supports images
 628    fn supports_images(&self) -> bool;
 629
 630    /// Whether this model supports tools.
 631    fn supports_tools(&self) -> bool;
 632
 633    /// Whether this model supports choosing which tool to use.
 634    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
 635
 636    /// Returns whether this model or provider supports streaming tool calls;
 637    fn supports_streaming_tools(&self) -> bool {
 638        false
 639    }
 640
 641    /// Returns whether this model/provider reports accurate split input/output token counts.
 642    /// When true, the UI may show separate input/output token indicators.
 643    fn supports_split_token_display(&self) -> bool {
 644        false
 645    }
 646
 647    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 648        LanguageModelToolSchemaFormat::JsonSchema
 649    }
 650
 651    fn max_token_count(&self) -> u64;
 652    fn max_output_tokens(&self) -> Option<u64> {
 653        None
 654    }
 655
 656    fn count_tokens(
 657        &self,
 658        request: LanguageModelRequest,
 659        cx: &App,
 660    ) -> BoxFuture<'static, Result<u64>>;
 661
 662    fn stream_completion(
 663        &self,
 664        request: LanguageModelRequest,
 665        cx: &AsyncApp,
 666    ) -> BoxFuture<
 667        'static,
 668        Result<
 669            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 670            LanguageModelCompletionError,
 671        >,
 672    >;
 673
 674    fn stream_completion_text(
 675        &self,
 676        request: LanguageModelRequest,
 677        cx: &AsyncApp,
 678    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
 679        let future = self.stream_completion(request, cx);
 680
 681        async move {
 682            let events = future.await?;
 683            let mut events = events.fuse();
 684            let mut message_id = None;
 685            let mut first_item_text = None;
 686            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
 687
 688            if let Some(first_event) = events.next().await {
 689                match first_event {
 690                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
 691                        message_id = Some(id);
 692                    }
 693                    Ok(LanguageModelCompletionEvent::Text(text)) => {
 694                        first_item_text = Some(text);
 695                    }
 696                    _ => (),
 697                }
 698            }
 699
 700            let stream = futures::stream::iter(first_item_text.map(Ok))
 701                .chain(events.filter_map({
 702                    let last_token_usage = last_token_usage.clone();
 703                    move |result| {
 704                        let last_token_usage = last_token_usage.clone();
 705                        async move {
 706                            match result {
 707                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
 708                                Ok(LanguageModelCompletionEvent::Started) => None,
 709                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
 710                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
 711                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
 712                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
 713                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
 714                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
 715                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
 716                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 717                                    ..
 718                                }) => None,
 719                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
 720                                    *last_token_usage.lock() = token_usage;
 721                                    None
 722                                }
 723                                Err(err) => Some(Err(err)),
 724                            }
 725                        }
 726                    }
 727                }))
 728                .boxed();
 729
 730            Ok(LanguageModelTextStream {
 731                message_id,
 732                stream,
 733                last_token_usage,
 734            })
 735        }
 736        .boxed()
 737    }
 738
 739    fn stream_completion_tool(
 740        &self,
 741        request: LanguageModelRequest,
 742        cx: &AsyncApp,
 743    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
 744        let future = self.stream_completion(request, cx);
 745
 746        async move {
 747            let events = future.await?;
 748            let mut events = events.fuse();
 749
 750            // Iterate through events until we find a complete ToolUse
 751            while let Some(event) = events.next().await {
 752                match event {
 753                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
 754                        if tool_use.is_input_complete =>
 755                    {
 756                        return Ok(tool_use);
 757                    }
 758                    Err(err) => {
 759                        return Err(err);
 760                    }
 761                    _ => {}
 762                }
 763            }
 764
 765            // Stream ended without a complete tool use
 766            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
 767                "Stream ended without receiving a complete tool use"
 768            )))
 769        }
 770        .boxed()
 771    }
 772
 773    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 774        None
 775    }
 776
 777    #[cfg(any(test, feature = "test-support"))]
 778    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
 779        unimplemented!()
 780    }
 781}
 782
 783impl std::fmt::Debug for dyn LanguageModel {
 784    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 785        f.debug_struct("<dyn LanguageModel>")
 786            .field("id", &self.id())
 787            .field("name", &self.name())
 788            .field("provider_id", &self.provider_id())
 789            .field("provider_name", &self.provider_name())
 790            .field("upstream_provider_name", &self.upstream_provider_name())
 791            .field("upstream_provider_id", &self.upstream_provider_id())
 792            .field("upstream_provider_id", &self.upstream_provider_id())
 793            .field("supports_streaming_tools", &self.supports_streaming_tools())
 794            .finish()
 795    }
 796}
 797
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider could not be reached at all.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 808
/// Either a built-in icon name or a path to an external SVG.
/// Returned by [`LanguageModelProvider::icon`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
 817
 818impl Default for IconOrSvg {
 819    fn default() -> Self {
 820        Self::Icon(IconName::ZedAssistant)
 821    }
 822}
 823
/// A provider of language models.
///
/// Implementors enumerate the models they offer, report and manage their
/// authentication state, and supply the UI used to configure the provider.
pub trait LanguageModelProvider: 'static {
    /// Stable machine-readable identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable display name for this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon to display for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if one is available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models the provider recommends surfacing to users; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider for `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
 846
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other agent, identified by the given string.
    Other(SharedString),
}
 853
/// The UI context in which a provider's terms-of-service view is shown.
// NOTE(review): the docs on the first two variants look swapped relative to
// the variant names (EmptyState = "some past interactions"?) — confirm with
// the call sites before relying on them.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    // Presumably shown in the text-thread popup — TODO confirm.
    TextThreadPopup,
    // Presumably shown in the provider configuration view — TODO confirm.
    Configuration,
}
 863
 864pub trait LanguageModelProviderState: 'static {
 865    type ObservableEntity;
 866
 867    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
 868
 869    fn subscribe<T: 'static>(
 870        &self,
 871        cx: &mut gpui::Context<T>,
 872        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
 873    ) -> Option<gpui::Subscription> {
 874        let entity = self.observable_entity()?;
 875        Some(cx.observe(&entity, move |this, _, cx| {
 876            callback(this, cx);
 877        }))
 878    }
 879}
 880
/// Newtype identifier for a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
 883
/// Newtype for a language model's human-readable name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
 886
/// Newtype identifier for a language model provider (e.g. the constants above).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
 889
/// Newtype for a provider's human-readable display name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
 892
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
 898
 899impl LanguageModelProviderName {
 900    pub const fn new(id: &'static str) -> Self {
 901        Self(SharedString::new_static(id))
 902    }
 903}
 904
 905impl fmt::Display for LanguageModelProviderId {
 906    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 907        write!(f, "{}", self.0)
 908    }
 909}
 910
 911impl fmt::Display for LanguageModelProviderName {
 912    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 913        write!(f, "{}", self.0)
 914    }
 915}
 916
 917impl From<String> for LanguageModelId {
 918    fn from(value: String) -> Self {
 919        Self(SharedString::from(value))
 920    }
 921}
 922
 923impl From<String> for LanguageModelName {
 924    fn from(value: String) -> Self {
 925        Self(SharedString::from(value))
 926    }
 927}
 928
 929impl From<String> for LanguageModelProviderId {
 930    fn from(value: String) -> Self {
 931        Self(SharedString::from(value))
 932    }
 933}
 934
 935impl From<String> for LanguageModelProviderName {
 936    fn from(value: String) -> Self {
 937        Self(SharedString::from(value))
 938    }
 939}
 940
 941impl From<Arc<str>> for LanguageModelProviderId {
 942    fn from(value: Arc<str>) -> Self {
 943        Self(SharedString::from(value))
 944    }
 945}
 946
 947impl From<Arc<str>> for LanguageModelProviderName {
 948    fn from(value: Arc<str>) -> Self {
 949        Self(SharedString::from(value))
 950    }
 951}
 952
#[cfg(test)]
mod tests {
    use super::*;

    // Cloud "upstream_http_error" payloads carry the real upstream status in a
    // JSON `upstream_status` field; verify 503 maps to ServerOverloaded and
    // 500 maps to ApiInternalServerError.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Fallback format: the status code is embedded in the error code itself
    // ("upstream_http_503") rather than in a JSON payload.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // For identical connection-timeout messages, only `upstream_status`
    // (503 vs 500) decides the variant, and the upstream message text is
    // preserved verbatim in the 500 case.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // `thought_signature` is included in the serialized form when present.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // A missing `thought_signature` field deserializes to `None` (backward
    // compatibility with payloads written before the field existed).
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Serialize → deserialize round trip preserves a present signature.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Round trip also works when the signature is absent.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}