language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
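
// Usage sketch (editorial; assumes this crate is consumed as `language_model`
// and that `client` is the `Arc<Client>` created at startup):
//
//     language_model::init(client.clone(), cx);
//
// `init_settings` alone is enough for callers that need the provider registry
// without registering the LLM token refresh listener.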

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    ReasoningDetails(serde_json::Value),
    UsageUpdate(TokenUsage),
}
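
// Consumption sketch (illustrative only; `events` is assumed to be the stream
// returned by `LanguageModel::stream_completion` below):
//
//     while let Some(event) = events.next().await {
//         match event? {
//             LanguageModelCompletionEvent::Text(chunk) => output.push_str(&chunk),
//             LanguageModelCompletionEvent::UsageUpdate(usage) => latest_usage = usage,
//             LanguageModelCompletionEvent::Stop(_) => break,
//             _ => {}
//         }
//     }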

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
 142    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: an Anthropic PAYLOAD_TOO_LARGE response may currently be reported as
            // INTERNAL_SERVER_ERROR. As a temporary workaround, treat any message that
            // parses as "prompt too long" as an exceeded token limit.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
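
// Mapping sketch (illustrative): the cloud protocol reports upstream failures
// as string codes, and `upstream_http_<status>` codes resolve to the matching
// typed error. For example:
//
//     let error = LanguageModelCompletionError::from_cloud_failure(
//         ANTHROPIC_PROVIDER_NAME,
//         "upstream_http_429".to_string(),
//         "rate limited".to_string(),
//         Some(Duration::from_secs(30)),
//     );
//     // => RateLimitExceeded { provider: "Anthropic", retry_after: Some(30s) }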

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}
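
// Serde sketch: `rename_all = "snake_case"` means stop reasons round-trip as
// lowercase strings:
//
//     assert_eq!(
//         serde_json::to_string(&StopReason::EndTurn).unwrap(),
//         "\"end_turn\""
//     );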

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}
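
// Arithmetic sketch (illustrative): the `Add`/`Sub` impls let callers
// accumulate usage across turns or diff two snapshots. Note that `Sub` uses
// plain `-`, so subtracting a larger count underflows (panicking in debug
// builds).
//
//     let before = TokenUsage { input_tokens: 100, output_tokens: 20, ..Default::default() };
//     let after = TokenUsage { input_tokens: 130, output_tokens: 55, ..Default::default() };
//     let delta = after - before; // input: 30, output: 35
//     assert_eq!(delta.total_tokens(), 65);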

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
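
// Usage sketch (illustrative; `model` is assumed to implement `LanguageModel`):
// drain the stream first, then read the final usage.
//
//     let mut text_stream = model.stream_completion_text(request, &cx).await?;
//     let mut text = String::new();
//     while let Some(chunk) = text_stream.stream.next().await {
//         text.push_str(&chunk?);
//     }
//     let usage = *text_stream.last_token_usage.lock();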

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (if `supports_burn_mode` is `false`, this returns `None`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}
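
// Sketch: callers choose the effective context window by completion mode;
// models without burn mode simply report their normal limit.
//
//     let limit = model.max_token_count_for_mode(CompletionMode::Max);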

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
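
// Subscription sketch (illustrative; `MyView` is a hypothetical gpui view
// type and `provider` some `LanguageModelProviderState`): observe the
// provider's state entity and re-render on change.
//
//     let _subscription = provider.subscribe(cx, |_this: &mut MyView, cx| {
//         cx.notify();
//     });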

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}