// language_model.rs

   1mod model;
   2mod rate_limiter;
   3mod registry;
   4mod request;
   5mod role;
   6mod telemetry;
   7pub mod tool_schema;
   8
   9#[cfg(any(test, feature = "test-support"))]
  10pub mod fake_provider;
  11
  12use anthropic::{AnthropicError, parse_prompt_too_long};
  13use anyhow::{Result, anyhow};
  14use client::Client;
  15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
  16use futures::FutureExt;
  17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
  18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
  19use http_client::{StatusCode, http};
  20use icons::IconName;
  21use open_router::OpenRouterError;
  22use parking_lot::Mutex;
  23use serde::{Deserialize, Serialize};
  24pub use settings::LanguageModelCacheConfiguration;
  25use std::ops::{Add, Sub};
  26use std::str::FromStr;
  27use std::sync::Arc;
  28use std::time::Duration;
  29use std::{fmt, io};
  30use thiserror::Error;
  31use util::serde::is_default;
  32
  33pub use crate::model::*;
  34pub use crate::rate_limiter::*;
  35pub use crate::registry::*;
  36pub use crate::request::*;
  37pub use crate::role::*;
  38pub use crate::telemetry::*;
  39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
  40
// Well-known provider identifiers. The `_ID` is the stable machine
// identifier (e.g. used as a key), the `_NAME` is the user-facing label.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
  60
/// Initializes the language-model crate: registers its settings and starts
/// listening on `client` for LLM token refresh events.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
  65
/// Registers this crate's settings/registry state; usable without a `Client`.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
  69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Progress update on the request itself, as reported by the cloud
    /// completion endpoint.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of ordinary response text.
    Text(String),
    /// A chunk of the model's reasoning text, with an optional signature
    /// (see `LanguageModelToolUse::thought_signature` for why signatures
    /// must be preserved).
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning content the provider returns only in opaque/redacted form.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new assistant message began; carries its provider-assigned ID.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage counts for the request.
    UsageUpdate(TokenUsage),
}
  95
  96#[derive(Error, Debug)]
  97pub enum LanguageModelCompletionError {
  98    #[error("prompt too large for context window")]
  99    PromptTooLarge { tokens: Option<u64> },
 100    #[error("missing {provider} API key")]
 101    NoApiKey { provider: LanguageModelProviderName },
 102    #[error("{provider}'s API rate limit exceeded")]
 103    RateLimitExceeded {
 104        provider: LanguageModelProviderName,
 105        retry_after: Option<Duration>,
 106    },
 107    #[error("{provider}'s API servers are overloaded right now")]
 108    ServerOverloaded {
 109        provider: LanguageModelProviderName,
 110        retry_after: Option<Duration>,
 111    },
 112    #[error("{provider}'s API server reported an internal server error: {message}")]
 113    ApiInternalServerError {
 114        provider: LanguageModelProviderName,
 115        message: String,
 116    },
 117    #[error("{message}")]
 118    UpstreamProviderError {
 119        message: String,
 120        status: StatusCode,
 121        retry_after: Option<Duration>,
 122    },
 123    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
 124    HttpResponseError {
 125        provider: LanguageModelProviderName,
 126        status_code: StatusCode,
 127        message: String,
 128    },
 129
 130    // Client errors
 131    #[error("invalid request format to {provider}'s API: {message}")]
 132    BadRequestFormat {
 133        provider: LanguageModelProviderName,
 134        message: String,
 135    },
 136    #[error("authentication error with {provider}'s API: {message}")]
 137    AuthenticationError {
 138        provider: LanguageModelProviderName,
 139        message: String,
 140    },
 141    #[error("Permission error with {provider}'s API: {message}")]
 142    PermissionError {
 143        provider: LanguageModelProviderName,
 144        message: String,
 145    },
 146    #[error("language model provider API endpoint not found")]
 147    ApiEndpointNotFound { provider: LanguageModelProviderName },
 148    #[error("I/O error reading response from {provider}'s API")]
 149    ApiReadResponseError {
 150        provider: LanguageModelProviderName,
 151        #[source]
 152        error: io::Error,
 153    },
 154    #[error("error serializing request to {provider} API")]
 155    SerializeRequest {
 156        provider: LanguageModelProviderName,
 157        #[source]
 158        error: serde_json::Error,
 159    },
 160    #[error("error building request body to {provider} API")]
 161    BuildRequestBody {
 162        provider: LanguageModelProviderName,
 163        #[source]
 164        error: http::Error,
 165    },
 166    #[error("error sending HTTP request to {provider} API")]
 167    HttpSend {
 168        provider: LanguageModelProviderName,
 169        #[source]
 170        error: anyhow::Error,
 171    },
 172    #[error("error deserializing {provider} API response")]
 173    DeserializeResponse {
 174        provider: LanguageModelProviderName,
 175        #[source]
 176        error: serde_json::Error,
 177    },
 178
 179    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
 180    #[error(transparent)]
 181    Other(#[from] anyhow::Error),
 182}
 183
impl LanguageModelCompletionError {
    /// Extracts the upstream HTTP status and inner message from a cloud
    /// `upstream_http_error` JSON payload.
    ///
    /// Returns `None` when the payload is not JSON or lacks a valid
    /// `upstream_status` number; when the `message` field is missing, the
    /// entire payload string is used as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Builds an error from a failure reported by the Zed cloud completion
    /// endpoint.
    ///
    /// `code` distinguishes between errors originating at the upstream
    /// provider (`upstream_http_error`, `upstream_http_<status>`) and errors
    /// from the cloud service itself (`http_<status>`); anything else falls
    /// back to `Other`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        // Checked first, regardless of `code`, because the token-limit text
        // can arrive under a misleading error code (see TODO below).
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            // The status lives inside the JSON payload; attribute the error
            // to the upstream provider.
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // e.g. "upstream_http_503" — status encoded in the code itself.
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // Error from the Zed cloud service itself, not the upstream.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from a provider's API to the matching error
    /// variant; unrecognized statuses become `HttpResponseError`.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // Non-standard 529 is also treated as "servers overloaded".
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
 275
 276impl From<AnthropicError> for LanguageModelCompletionError {
 277    fn from(error: AnthropicError) -> Self {
 278        let provider = ANTHROPIC_PROVIDER_NAME;
 279        match error {
 280            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
 281            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
 282            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
 283            AnthropicError::DeserializeResponse(error) => {
 284                Self::DeserializeResponse { provider, error }
 285            }
 286            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
 287            AnthropicError::HttpResponseError {
 288                status_code,
 289                message,
 290            } => Self::HttpResponseError {
 291                provider,
 292                status_code,
 293                message,
 294            },
 295            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
 296                provider,
 297                retry_after: Some(retry_after),
 298            },
 299            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
 300                provider,
 301                retry_after,
 302            },
 303            AnthropicError::ApiError(api_error) => api_error.into(),
 304        }
 305    }
 306}
 307
 308impl From<anthropic::ApiError> for LanguageModelCompletionError {
 309    fn from(error: anthropic::ApiError) -> Self {
 310        use anthropic::ApiErrorCode::*;
 311        let provider = ANTHROPIC_PROVIDER_NAME;
 312        match error.code() {
 313            Some(code) => match code {
 314                InvalidRequestError => Self::BadRequestFormat {
 315                    provider,
 316                    message: error.message,
 317                },
 318                AuthenticationError => Self::AuthenticationError {
 319                    provider,
 320                    message: error.message,
 321                },
 322                PermissionError => Self::PermissionError {
 323                    provider,
 324                    message: error.message,
 325                },
 326                NotFoundError => Self::ApiEndpointNotFound { provider },
 327                RequestTooLarge => Self::PromptTooLarge {
 328                    tokens: parse_prompt_too_long(&error.message),
 329                },
 330                RateLimitError => Self::RateLimitExceeded {
 331                    provider,
 332                    retry_after: None,
 333                },
 334                ApiError => Self::ApiInternalServerError {
 335                    provider,
 336                    message: error.message,
 337                },
 338                OverloadedError => Self::ServerOverloaded {
 339                    provider,
 340                    retry_after: None,
 341                },
 342            },
 343            None => Self::Other(error.into()),
 344        }
 345    }
 346}
 347
 348impl From<open_ai::RequestError> for LanguageModelCompletionError {
 349    fn from(error: open_ai::RequestError) -> Self {
 350        match error {
 351            open_ai::RequestError::HttpResponseError {
 352                provider,
 353                status_code,
 354                body,
 355                headers,
 356            } => {
 357                let retry_after = headers
 358                    .get(http::header::RETRY_AFTER)
 359                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
 360                    .map(Duration::from_secs);
 361
 362                Self::from_http_status(provider.into(), status_code, body, retry_after)
 363            }
 364            open_ai::RequestError::Other(e) => Self::Other(e),
 365        }
 366    }
 367}
 368
 369impl From<OpenRouterError> for LanguageModelCompletionError {
 370    fn from(error: OpenRouterError) -> Self {
 371        let provider = LanguageModelProviderName::new("OpenRouter");
 372        match error {
 373            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
 374            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
 375            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
 376            OpenRouterError::DeserializeResponse(error) => {
 377                Self::DeserializeResponse { provider, error }
 378            }
 379            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
 380            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
 381                provider,
 382                retry_after: Some(retry_after),
 383            },
 384            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
 385                provider,
 386                retry_after,
 387            },
 388            OpenRouterError::ApiError(api_error) => api_error.into(),
 389        }
 390    }
 391}
 392
 393impl From<open_router::ApiError> for LanguageModelCompletionError {
 394    fn from(error: open_router::ApiError) -> Self {
 395        use open_router::ApiErrorCode::*;
 396        let provider = LanguageModelProviderName::new("OpenRouter");
 397        match error.code {
 398            InvalidRequestError => Self::BadRequestFormat {
 399                provider,
 400                message: error.message,
 401            },
 402            AuthenticationError => Self::AuthenticationError {
 403                provider,
 404                message: error.message,
 405            },
 406            PaymentRequiredError => Self::AuthenticationError {
 407                provider,
 408                message: format!("Payment required: {}", error.message),
 409            },
 410            PermissionError => Self::PermissionError {
 411                provider,
 412                message: error.message,
 413            },
 414            RequestTimedOut => Self::HttpResponseError {
 415                provider,
 416                status_code: StatusCode::REQUEST_TIMEOUT,
 417                message: error.message,
 418            },
 419            RateLimitError => Self::RateLimitExceeded {
 420                provider,
 421                retry_after: None,
 422            },
 423            ApiError => Self::ApiInternalServerError {
 424                provider,
 425                message: error.message,
 426            },
 427            OverloadedError => Self::ServerOverloaded {
 428                provider,
 429                retry_after: None,
 430            },
 431        }
 432    }
 433}
 434
/// Why a model stopped generating, mirroring provider stop/finish reasons.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce the completion.
    Refusal,
}
 443
/// Token usage counts reported by a provider for a completion.
///
/// Every field defaults to zero and is omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Prompt tokens. Presumably excludes the cache-related counts below,
    // since total_tokens() sums all four — confirm per provider.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
 455
 456impl TokenUsage {
 457    pub fn total_tokens(&self) -> u64 {
 458        self.input_tokens
 459            + self.output_tokens
 460            + self.cache_read_input_tokens
 461            + self.cache_creation_input_tokens
 462    }
 463}
 464
 465impl Add<TokenUsage> for TokenUsage {
 466    type Output = Self;
 467
 468    fn add(self, other: Self) -> Self {
 469        Self {
 470            input_tokens: self.input_tokens + other.input_tokens,
 471            output_tokens: self.output_tokens + other.output_tokens,
 472            cache_creation_input_tokens: self.cache_creation_input_tokens
 473                + other.cache_creation_input_tokens,
 474            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
 475        }
 476    }
 477}
 478
 479impl Sub<TokenUsage> for TokenUsage {
 480    type Output = Self;
 481
 482    fn sub(self, other: Self) -> Self {
 483        Self {
 484            input_tokens: self.input_tokens - other.input_tokens,
 485            output_tokens: self.output_tokens - other.output_tokens,
 486            cache_creation_input_tokens: self.cache_creation_input_tokens
 487                - other.cache_creation_input_tokens,
 488            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
 489        }
 490    }
 491}
 492
/// Opaque identifier a model assigns to a tool-use request.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Constructible from anything convertible to `Arc<str>` (`&str`, `String`, ...).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
 510
/// A tool invocation requested by a model during a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Model-assigned ID for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The tool input exactly as streamed from the model (possibly partial).
    pub raw_input: String,
    /// The parsed JSON form of `raw_input`.
    pub input: serde_json::Value,
    /// Whether the model has finished streaming the input.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
 522
/// A plain-text view of a completion: just the text chunks, plus the message
/// ID (when known) and a slot that receives token usage as it streams.
pub struct LanguageModelTextStream {
    /// Provider-assigned message ID, if a `StartMessage` event was seen.
    pub message_id: Option<String>,
    /// The stream of text chunks (or errors).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    /// An already-finished stream: no message ID, no text, zero usage.
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
 539
/// A handle to a specific language model: its identity, capabilities, limits,
/// and streaming completion entry points.
pub trait LanguageModel: Send + Sync {
    /// Stable identifier for this model.
    fn id(&self) -> LanguageModelId;
    /// Human-readable model name.
    fn name(&self) -> LanguageModelName;
    /// Identifier of the provider serving this model.
    fn provider_id(&self) -> LanguageModelProviderId;
    /// Human-readable name of the provider serving this model.
    fn provider_name(&self) -> LanguageModelProviderName;
    /// The provider actually hosting the model when requests are relayed
    /// (defaults to `provider_id`).
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Display-name counterpart of `upstream_provider_id`.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Identifier used when reporting telemetry about this model.
    fn telemetry_id(&self) -> String;

    /// The configured API key for this model, if any (default: none).
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode";
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// The schema dialect the model expects for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum context window size, in tokens.
    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    /// Maximum number of output tokens, if the model imposes one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Starts a completion, resolving to a stream of completion events
    /// (text, thinking, tool use, usage updates, ...) or an error.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper around `stream_completion` that reduces the event
    /// stream to plain text: text chunks and errors flow through, the message
    /// ID is captured from an initial `StartMessage`, and each `UsageUpdate`
    /// is recorded into `last_token_usage`.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Inspect the first event to capture the message ID, keeping any
            // text it carried so it is not lost from the output stream.
            // NOTE(review): every other first event — including an `Err` and
            // a `UsageUpdate` — is silently discarded by the `_` arm here;
            // confirm that is intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            // Only text and errors pass through; usage updates
                            // are recorded as a side effect, all other event
                            // kinds are dropped.
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Prompt-caching configuration, if the model supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Test-only downcast helper; panics unless overridden by the fake model.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
 675
 676pub trait LanguageModelExt: LanguageModel {
 677    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
 678        match mode {
 679            CompletionMode::Normal => self.max_token_count(),
 680            CompletionMode::Max => self
 681                .max_token_count_in_burn_mode()
 682                .unwrap_or_else(|| self.max_token_count()),
 683        }
 684    }
 685}
 686impl LanguageModelExt for dyn LanguageModel {}
 687
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 698
/// A source of language models (e.g. Anthropic, OpenAI, Zed cloud), exposing
/// its models plus authentication and configuration flows.
pub trait LanguageModelProvider: 'static {
    /// Stable machine identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for the provider in the UI.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if it can offer one.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A faster/cheaper default model for lightweight tasks.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models the provider recommends surfacing; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider currently has working credentials.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
 721
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// A named external agent.
    Other(SharedString),
}
 728
/// Where a provider's terms-of-service notice is being displayed.
// NOTE(review): the comments on the two `Thread*` variants look swapped
// relative to the variant names ("EmptyState" documented as having past
// interactions) — verify against the call sites.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// In the text-thread popup.
    TextThreadPopup,
    /// In the provider configuration UI.
    Configuration,
}
 738
/// Gives a provider an observable entity so UI can react to its state changes.
pub trait LanguageModelProviderState: 'static {
    /// The gpui entity whose updates signal provider state changes.
    type ObservableEntity;

    /// The entity to observe, or `None` if this provider has no observable state.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Observes the provider's entity, invoking `callback` on every change.
    /// Returns `None` when there is nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
 755
/// Identifier for a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable machine identifier for a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
 767
impl LanguageModelProviderId {
    /// Const constructor from a static string, usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Const constructor from a static string, usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
 779
// Display the wrapped string directly, with no decoration.
impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
 791
 792impl From<String> for LanguageModelId {
 793    fn from(value: String) -> Self {
 794        Self(SharedString::from(value))
 795    }
 796}
 797
 798impl From<String> for LanguageModelName {
 799    fn from(value: String) -> Self {
 800        Self(SharedString::from(value))
 801    }
 802}
 803
 804impl From<String> for LanguageModelProviderId {
 805    fn from(value: String) -> Self {
 806        Self(SharedString::from(value))
 807    }
 808}
 809
 810impl From<String> for LanguageModelProviderName {
 811    fn from(value: String) -> Self {
 812        Self(SharedString::from(value))
 813    }
 814}
 815
 816impl From<Arc<str>> for LanguageModelProviderId {
 817    fn from(value: Arc<str>) -> Self {
 818        Self(SharedString::from(value))
 819    }
 820}
 821
 822impl From<Arc<str>> for LanguageModelProviderName {
 823    fn from(value: Arc<str>) -> Self {
 824        Self(SharedString::from(value))
 825    }
 826}
 827
 828#[cfg(test)]
 829mod tests {
 830    use super::*;
 831
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        // An "upstream_http_error" whose JSON payload reports a 503 should be
        // surfaced as ServerOverloaded, attributed to the upstream provider.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        // A payload reporting a 500 should become ApiInternalServerError,
        // carrying the inner "message" field (not the whole JSON payload).
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }
 869
 870    #[test]
 871    fn test_from_cloud_failure_with_standard_format() {
 872        let error = LanguageModelCompletionError::from_cloud_failure(
 873            String::from("anthropic").into(),
 874            "upstream_http_503".to_string(),
 875            "Service unavailable".to_string(),
 876            None,
 877        );
 878
 879        match error {
 880            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
 881                assert_eq!(provider.0, "anthropic");
 882            }
 883            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
 884        }
 885    }
 886
 887    #[test]
 888    fn test_upstream_http_error_connection_timeout() {
 889        let error = LanguageModelCompletionError::from_cloud_failure(
 890            String::from("anthropic").into(),
 891            "upstream_http_error".to_string(),
 892            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
 893            None,
 894        );
 895
 896        match error {
 897            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
 898                assert_eq!(provider.0, "anthropic");
 899            }
 900            _ => panic!(
 901                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
 902                error
 903            ),
 904        }
 905
 906        let error = LanguageModelCompletionError::from_cloud_failure(
 907            String::from("anthropic").into(),
 908            "upstream_http_error".to_string(),
 909            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
 910            None,
 911        );
 912
 913        match error {
 914            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
 915                assert_eq!(provider.0, "anthropic");
 916                assert_eq!(
 917                    message,
 918                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
 919                );
 920            }
 921            _ => panic!(
 922                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
 923                error
 924            ),
 925        }
 926    }
 927
 928    #[test]
 929    fn test_language_model_tool_use_serializes_with_signature() {
 930        use serde_json::json;
 931
 932        let tool_use = LanguageModelToolUse {
 933            id: LanguageModelToolUseId::from("test_id"),
 934            name: "test_tool".into(),
 935            raw_input: json!({"arg": "value"}).to_string(),
 936            input: json!({"arg": "value"}),
 937            is_input_complete: true,
 938            thought_signature: Some("test_signature".to_string()),
 939        };
 940
 941        let serialized = serde_json::to_value(&tool_use).unwrap();
 942
 943        assert_eq!(serialized["id"], "test_id");
 944        assert_eq!(serialized["name"], "test_tool");
 945        assert_eq!(serialized["thought_signature"], "test_signature");
 946    }
 947
 948    #[test]
 949    fn test_language_model_tool_use_deserializes_with_missing_signature() {
 950        use serde_json::json;
 951
 952        let json = json!({
 953            "id": "test_id",
 954            "name": "test_tool",
 955            "raw_input": "{\"arg\":\"value\"}",
 956            "input": {"arg": "value"},
 957            "is_input_complete": true
 958        });
 959
 960        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
 961
 962        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
 963        assert_eq!(tool_use.name.as_ref(), "test_tool");
 964        assert_eq!(tool_use.thought_signature, None);
 965    }
 966
 967    #[test]
 968    fn test_language_model_tool_use_round_trip_with_signature() {
 969        use serde_json::json;
 970
 971        let original = LanguageModelToolUse {
 972            id: LanguageModelToolUseId::from("round_trip_id"),
 973            name: "round_trip_tool".into(),
 974            raw_input: json!({"key": "value"}).to_string(),
 975            input: json!({"key": "value"}),
 976            is_input_complete: true,
 977            thought_signature: Some("round_trip_sig".to_string()),
 978        };
 979
 980        let serialized = serde_json::to_value(&original).unwrap();
 981        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
 982
 983        assert_eq!(deserialized.id, original.id);
 984        assert_eq!(deserialized.name, original.name);
 985        assert_eq!(deserialized.thought_signature, original.thought_signature);
 986    }
 987
 988    #[test]
 989    fn test_language_model_tool_use_round_trip_without_signature() {
 990        use serde_json::json;
 991
 992        let original = LanguageModelToolUse {
 993            id: LanguageModelToolUseId::from("no_sig_id"),
 994            name: "no_sig_tool".into(),
 995            raw_input: json!({"key": "value"}).to_string(),
 996            input: json!({"key": "value"}),
 997            is_input_complete: true,
 998            thought_signature: None,
 999        };
1000
1001        let serialized = serde_json::to_value(&original).unwrap();
1002        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
1003
1004        assert_eq!(deserialized.id, original.id);
1005        assert_eq!(deserialized.name, original.name);
1006        assert_eq!(deserialized.thought_signature, None);
1007    }
1008}