language_model.rs

   1mod model;
   2mod rate_limiter;
   3mod registry;
   4mod request;
   5mod role;
   6mod telemetry;
   7pub mod tool_schema;
   8
   9#[cfg(any(test, feature = "test-support"))]
  10pub mod fake_provider;
  11
  12use anthropic::{AnthropicError, parse_prompt_too_long};
  13use anyhow::{Result, anyhow};
  14use client::Client;
  15use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
  16use futures::FutureExt;
  17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
  18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
  19use http_client::{StatusCode, http};
  20use icons::IconName;
  21use open_router::OpenRouterError;
  22use parking_lot::Mutex;
  23use serde::{Deserialize, Serialize};
  24pub use settings::LanguageModelCacheConfiguration;
  25use std::ops::{Add, Sub};
  26use std::str::FromStr;
  27use std::sync::Arc;
  28use std::time::Duration;
  29use std::{fmt, io};
  30use thiserror::Error;
  31use util::serde::is_default;
  32
  33pub use crate::model::*;
  34pub use crate::rate_limiter::*;
  35pub use crate::registry::*;
  36pub use crate::request::*;
  37pub use crate::role::*;
  38pub use crate::telemetry::*;
  39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
  40
// Well-known provider identifiers (stable, machine-readable) and their
// user-facing display names.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
  60
/// Initializes the language model system: registers settings and wires up
/// the LLM token refresh listener on the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
  65
/// Registers only the settings/registry portion of the system (used when no
/// client is available, e.g. in some tests or tools).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
  69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is queued upstream; `position` is its place in the queue.
    Queued {
        position: usize,
    },
    /// The upstream provider has started processing the request.
    Started,
    /// Plan/usage-limit progress reported by the cloud request status
    /// (see `from_completion_request_status`); distinct from `UsageUpdate`.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    /// The request hit the tool-use limit.
    ToolUseLimitReached,
    /// The model stopped generating, with the given reason.
    Stop(StopReason),
    /// A chunk of assistant text.
    Text(String),
    /// A chunk of model "thinking" output, optionally with a signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-thinking payload. NOTE(review): presumably must be
    /// round-tripped back to the provider unchanged — confirm with callers.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new assistant message began; carries its provider-assigned id.
    StartMessage {
        message_id: String,
    },
    /// Token-usage snapshot; the last one received reflects the whole stream
    /// (see `LanguageModelTextStream::last_token_usage`).
    UsageUpdate(TokenUsage),
}
 103
 104impl LanguageModelCompletionEvent {
 105    pub fn from_completion_request_status(
 106        status: CompletionRequestStatus,
 107        upstream_provider: LanguageModelProviderName,
 108    ) -> Result<Self, LanguageModelCompletionError> {
 109        match status {
 110            CompletionRequestStatus::Queued { position } => {
 111                Ok(LanguageModelCompletionEvent::Queued { position })
 112            }
 113            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
 114            CompletionRequestStatus::UsageUpdated { amount, limit } => {
 115                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
 116            }
 117            CompletionRequestStatus::ToolUseLimitReached => {
 118                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
 119            }
 120            CompletionRequestStatus::Failed {
 121                code,
 122                message,
 123                request_id: _,
 124                retry_after,
 125            } => Err(LanguageModelCompletionError::from_cloud_failure(
 126                upstream_provider,
 127                code,
 128                message,
 129                retry_after.map(Duration::from_secs_f64),
 130            )),
 131        }
 132    }
 133}
 134
/// Errors that can occur while requesting or streaming a completion.
/// The `#[error]` attributes are the user-facing messages.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    // Capacity / server-side errors
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    // An error relayed verbatim from an upstream provider via the cloud.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    // Catch-all for HTTP statuses not mapped to a variant above.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    // Local failures while building/sending the request or decoding the response.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 222
 223impl LanguageModelCompletionError {
 224    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
 225        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
 226        let upstream_status = error_json
 227            .get("upstream_status")
 228            .and_then(|v| v.as_u64())
 229            .and_then(|status| u16::try_from(status).ok())
 230            .and_then(|status| StatusCode::from_u16(status).ok())?;
 231        let inner_message = error_json
 232            .get("message")
 233            .and_then(|v| v.as_str())
 234            .unwrap_or(message)
 235            .to_string();
 236        Some((upstream_status, inner_message))
 237    }
 238
 239    pub fn from_cloud_failure(
 240        upstream_provider: LanguageModelProviderName,
 241        code: String,
 242        message: String,
 243        retry_after: Option<Duration>,
 244    ) -> Self {
 245        if let Some(tokens) = parse_prompt_too_long(&message) {
 246            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
 247            // to be reported. This is a temporary workaround to handle this in the case where the
 248            // token limit has been exceeded.
 249            Self::PromptTooLarge {
 250                tokens: Some(tokens),
 251            }
 252        } else if code == "upstream_http_error" {
 253            if let Some((upstream_status, inner_message)) =
 254                Self::parse_upstream_error_json(&message)
 255            {
 256                return Self::from_http_status(
 257                    upstream_provider,
 258                    upstream_status,
 259                    inner_message,
 260                    retry_after,
 261                );
 262            }
 263            anyhow!("completion request failed, code: {code}, message: {message}").into()
 264        } else if let Some(status_code) = code
 265            .strip_prefix("upstream_http_")
 266            .and_then(|code| StatusCode::from_str(code).ok())
 267        {
 268            Self::from_http_status(upstream_provider, status_code, message, retry_after)
 269        } else if let Some(status_code) = code
 270            .strip_prefix("http_")
 271            .and_then(|code| StatusCode::from_str(code).ok())
 272        {
 273            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
 274        } else {
 275            anyhow!("completion request failed, code: {code}, message: {message}").into()
 276        }
 277    }
 278
 279    pub fn from_http_status(
 280        provider: LanguageModelProviderName,
 281        status_code: StatusCode,
 282        message: String,
 283        retry_after: Option<Duration>,
 284    ) -> Self {
 285        match status_code {
 286            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
 287            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
 288            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
 289            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
 290            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
 291                tokens: parse_prompt_too_long(&message),
 292            },
 293            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
 294                provider,
 295                retry_after,
 296            },
 297            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
 298            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
 299                provider,
 300                retry_after,
 301            },
 302            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
 303                provider,
 304                retry_after,
 305            },
 306            _ => Self::HttpResponseError {
 307                provider,
 308                status_code,
 309                message,
 310            },
 311        }
 312    }
 313}
 314
 315impl From<AnthropicError> for LanguageModelCompletionError {
 316    fn from(error: AnthropicError) -> Self {
 317        let provider = ANTHROPIC_PROVIDER_NAME;
 318        match error {
 319            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
 320            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
 321            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
 322            AnthropicError::DeserializeResponse(error) => {
 323                Self::DeserializeResponse { provider, error }
 324            }
 325            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
 326            AnthropicError::HttpResponseError {
 327                status_code,
 328                message,
 329            } => Self::HttpResponseError {
 330                provider,
 331                status_code,
 332                message,
 333            },
 334            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
 335                provider,
 336                retry_after: Some(retry_after),
 337            },
 338            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
 339                provider,
 340                retry_after,
 341            },
 342            AnthropicError::ApiError(api_error) => api_error.into(),
 343        }
 344    }
 345}
 346
 347impl From<anthropic::ApiError> for LanguageModelCompletionError {
 348    fn from(error: anthropic::ApiError) -> Self {
 349        use anthropic::ApiErrorCode::*;
 350        let provider = ANTHROPIC_PROVIDER_NAME;
 351        match error.code() {
 352            Some(code) => match code {
 353                InvalidRequestError => Self::BadRequestFormat {
 354                    provider,
 355                    message: error.message,
 356                },
 357                AuthenticationError => Self::AuthenticationError {
 358                    provider,
 359                    message: error.message,
 360                },
 361                PermissionError => Self::PermissionError {
 362                    provider,
 363                    message: error.message,
 364                },
 365                NotFoundError => Self::ApiEndpointNotFound { provider },
 366                RequestTooLarge => Self::PromptTooLarge {
 367                    tokens: parse_prompt_too_long(&error.message),
 368                },
 369                RateLimitError => Self::RateLimitExceeded {
 370                    provider,
 371                    retry_after: None,
 372                },
 373                ApiError => Self::ApiInternalServerError {
 374                    provider,
 375                    message: error.message,
 376                },
 377                OverloadedError => Self::ServerOverloaded {
 378                    provider,
 379                    retry_after: None,
 380                },
 381            },
 382            None => Self::Other(error.into()),
 383        }
 384    }
 385}
 386
 387impl From<open_ai::RequestError> for LanguageModelCompletionError {
 388    fn from(error: open_ai::RequestError) -> Self {
 389        match error {
 390            open_ai::RequestError::HttpResponseError {
 391                provider,
 392                status_code,
 393                body,
 394                headers,
 395            } => {
 396                let retry_after = headers
 397                    .get(http::header::RETRY_AFTER)
 398                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
 399                    .map(Duration::from_secs);
 400
 401                Self::from_http_status(provider.into(), status_code, body, retry_after)
 402            }
 403            open_ai::RequestError::Other(e) => Self::Other(e),
 404        }
 405    }
 406}
 407
 408impl From<OpenRouterError> for LanguageModelCompletionError {
 409    fn from(error: OpenRouterError) -> Self {
 410        let provider = LanguageModelProviderName::new("OpenRouter");
 411        match error {
 412            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
 413            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
 414            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
 415            OpenRouterError::DeserializeResponse(error) => {
 416                Self::DeserializeResponse { provider, error }
 417            }
 418            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
 419            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
 420                provider,
 421                retry_after: Some(retry_after),
 422            },
 423            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
 424                provider,
 425                retry_after,
 426            },
 427            OpenRouterError::ApiError(api_error) => api_error.into(),
 428        }
 429    }
 430}
 431
 432impl From<open_router::ApiError> for LanguageModelCompletionError {
 433    fn from(error: open_router::ApiError) -> Self {
 434        use open_router::ApiErrorCode::*;
 435        let provider = LanguageModelProviderName::new("OpenRouter");
 436        match error.code {
 437            InvalidRequestError => Self::BadRequestFormat {
 438                provider,
 439                message: error.message,
 440            },
 441            AuthenticationError => Self::AuthenticationError {
 442                provider,
 443                message: error.message,
 444            },
 445            PaymentRequiredError => Self::AuthenticationError {
 446                provider,
 447                message: format!("Payment required: {}", error.message),
 448            },
 449            PermissionError => Self::PermissionError {
 450                provider,
 451                message: error.message,
 452            },
 453            RequestTimedOut => Self::HttpResponseError {
 454                provider,
 455                status_code: StatusCode::REQUEST_TIMEOUT,
 456                message: error.message,
 457            },
 458            RateLimitError => Self::RateLimitExceeded {
 459                provider,
 460                retry_after: None,
 461            },
 462            ApiError => Self::ApiInternalServerError {
 463                provider,
 464                message: error.message,
 465            },
 466            OverloadedError => Self::ServerOverloaded {
 467                provider,
 468                retry_after: None,
 469            },
 470        }
 471    }
 472}
 473
/// Why the model stopped emitting output (serialized in `snake_case`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation hit the output-token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
 482
/// Token counts reported by a provider for a request/response.
/// Zero-valued fields are omitted from the serialized form.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens consumed from the request prompt.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens generated in the response.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
 494
 495impl TokenUsage {
 496    pub fn total_tokens(&self) -> u64 {
 497        self.input_tokens
 498            + self.output_tokens
 499            + self.cache_read_input_tokens
 500            + self.cache_creation_input_tokens
 501    }
 502}
 503
 504impl Add<TokenUsage> for TokenUsage {
 505    type Output = Self;
 506
 507    fn add(self, other: Self) -> Self {
 508        Self {
 509            input_tokens: self.input_tokens + other.input_tokens,
 510            output_tokens: self.output_tokens + other.output_tokens,
 511            cache_creation_input_tokens: self.cache_creation_input_tokens
 512                + other.cache_creation_input_tokens,
 513            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
 514        }
 515    }
 516}
 517
 518impl Sub<TokenUsage> for TokenUsage {
 519    type Output = Self;
 520
 521    fn sub(self, other: Self) -> Self {
 522        Self {
 523            input_tokens: self.input_tokens - other.input_tokens,
 524            output_tokens: self.output_tokens - other.output_tokens,
 525            cache_creation_input_tokens: self.cache_creation_input_tokens
 526                - other.cache_creation_input_tokens,
 527            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
 528        }
 529    }
 530}
 531
/// Provider-assigned identifier for a single tool-use request.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
 534
 535impl fmt::Display for LanguageModelToolUseId {
 536    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 537        write!(f, "{}", self.0)
 538    }
 539}
 540
/// Allows constructing an id from anything convertible to `Arc<str>`
/// (e.g. `&str` or `String`).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
 549
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Provider-assigned id for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool the model wants to invoke.
    pub name: Arc<str>,
    /// The tool input as received from the model. NOTE(review): presumably
    /// may be partial JSON while `is_input_complete` is false — confirm.
    pub raw_input: String,
    /// The tool input parsed as JSON.
    pub input: serde_json::Value,
    /// Whether the input has finished streaming.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
 561
/// A text-only view of a completion stream
/// (see `LanguageModel::stream_completion_text`).
pub struct LanguageModelTextStream {
    /// Id of the message, when the stream began with a `StartMessage` event.
    pub message_id: Option<String>,
    /// Stream of text chunks; non-text events are filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
 568
 569impl Default for LanguageModelTextStream {
 570    fn default() -> Self {
 571        Self {
 572            message_id: None,
 573            stream: Box::pin(futures::stream::empty()),
 574            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
 575        }
 576    }
 577}
 578
 579pub trait LanguageModel: Send + Sync {
 580    fn id(&self) -> LanguageModelId;
 581    fn name(&self) -> LanguageModelName;
 582    fn provider_id(&self) -> LanguageModelProviderId;
 583    fn provider_name(&self) -> LanguageModelProviderName;
 584    fn upstream_provider_id(&self) -> LanguageModelProviderId {
 585        self.provider_id()
 586    }
 587    fn upstream_provider_name(&self) -> LanguageModelProviderName {
 588        self.provider_name()
 589    }
 590
 591    fn telemetry_id(&self) -> String;
 592
 593    fn api_key(&self, _cx: &App) -> Option<String> {
 594        None
 595    }
 596
 597    /// Whether this model supports images
 598    fn supports_images(&self) -> bool;
 599
 600    /// Whether this model supports tools.
 601    fn supports_tools(&self) -> bool;
 602
 603    /// Whether this model supports choosing which tool to use.
 604    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
 605
 606    /// Returns whether this model supports "burn mode";
 607    fn supports_burn_mode(&self) -> bool {
 608        false
 609    }
 610
 611    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
 612        LanguageModelToolSchemaFormat::JsonSchema
 613    }
 614
 615    fn max_token_count(&self) -> u64;
 616    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
 617    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
 618        None
 619    }
 620    fn max_output_tokens(&self) -> Option<u64> {
 621        None
 622    }
 623
 624    fn count_tokens(
 625        &self,
 626        request: LanguageModelRequest,
 627        cx: &App,
 628    ) -> BoxFuture<'static, Result<u64>>;
 629
 630    fn stream_completion(
 631        &self,
 632        request: LanguageModelRequest,
 633        cx: &AsyncApp,
 634    ) -> BoxFuture<
 635        'static,
 636        Result<
 637            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 638            LanguageModelCompletionError,
 639        >,
 640    >;
 641
 642    fn stream_completion_text(
 643        &self,
 644        request: LanguageModelRequest,
 645        cx: &AsyncApp,
 646    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
 647        let future = self.stream_completion(request, cx);
 648
 649        async move {
 650            let events = future.await?;
 651            let mut events = events.fuse();
 652            let mut message_id = None;
 653            let mut first_item_text = None;
 654            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
 655
 656            if let Some(first_event) = events.next().await {
 657                match first_event {
 658                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
 659                        message_id = Some(id);
 660                    }
 661                    Ok(LanguageModelCompletionEvent::Text(text)) => {
 662                        first_item_text = Some(text);
 663                    }
 664                    _ => (),
 665                }
 666            }
 667
 668            let stream = futures::stream::iter(first_item_text.map(Ok))
 669                .chain(events.filter_map({
 670                    let last_token_usage = last_token_usage.clone();
 671                    move |result| {
 672                        let last_token_usage = last_token_usage.clone();
 673                        async move {
 674                            match result {
 675                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
 676                                Ok(LanguageModelCompletionEvent::Started) => None,
 677                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
 678                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
 679                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
 680                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
 681                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
 682                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
 683                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
 684                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
 685                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
 686                                    ..
 687                                }) => None,
 688                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
 689                                    *last_token_usage.lock() = token_usage;
 690                                    None
 691                                }
 692                                Err(err) => Some(Err(err)),
 693                            }
 694                        }
 695                    }
 696                }))
 697                .boxed();
 698
 699            Ok(LanguageModelTextStream {
 700                message_id,
 701                stream,
 702                last_token_usage,
 703            })
 704        }
 705        .boxed()
 706    }
 707
 708    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
 709        None
 710    }
 711
 712    #[cfg(any(test, feature = "test-support"))]
 713    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
 714        unimplemented!()
 715    }
 716}
 717
 718pub trait LanguageModelExt: LanguageModel {
 719    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
 720        match mode {
 721            CompletionMode::Normal => self.max_token_count(),
 722            CompletionMode::Max => self
 723                .max_token_count_in_burn_mode()
 724                .unwrap_or_else(|| self.max_token_count()),
 725        }
 726    }
 727}
 728impl LanguageModelExt for dyn LanguageModel {}
 729
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
 740
/// A source of language models (e.g. Anthropic, OpenAI, Zed's cloud).
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A faster/cheaper default model, if the provider distinguishes one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
 763
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// An external agent, identified by name.
    Other(SharedString),
}
 770
/// Where a provider's terms-of-service prompt is being rendered.
// NOTE(review): the comments on the two `Thread*` variants read as swapped
// relative to their names ("EmptyState" = some interactions?) — confirm.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
 780
 781pub trait LanguageModelProviderState: 'static {
 782    type ObservableEntity;
 783
 784    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
 785
 786    fn subscribe<T: 'static>(
 787        &self,
 788        cx: &mut gpui::Context<T>,
 789        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
 790    ) -> Option<gpui::Subscription> {
 791        let entity = self.observable_entity()?;
 792        Some(cx.observe(&entity, move |this, _, cx| {
 793            callback(this, cx);
 794        }))
 795    }
 796}
 797
/// Unique identifier for a model within a provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Unique identifier for a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
 809
impl LanguageModelProviderId {
    /// Creates a provider id from a static string, without allocating.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
 815
 816impl LanguageModelProviderName {
 817    pub const fn new(id: &'static str) -> Self {
 818        Self(SharedString::new_static(id))
 819    }
 820}
 821
 822impl fmt::Display for LanguageModelProviderId {
 823    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 824        write!(f, "{}", self.0)
 825    }
 826}
 827
 828impl fmt::Display for LanguageModelProviderName {
 829    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 830        write!(f, "{}", self.0)
 831    }
 832}
 833
 834impl From<String> for LanguageModelId {
 835    fn from(value: String) -> Self {
 836        Self(SharedString::from(value))
 837    }
 838}
 839
 840impl From<String> for LanguageModelName {
 841    fn from(value: String) -> Self {
 842        Self(SharedString::from(value))
 843    }
 844}
 845
 846impl From<String> for LanguageModelProviderId {
 847    fn from(value: String) -> Self {
 848        Self(SharedString::from(value))
 849    }
 850}
 851
 852impl From<String> for LanguageModelProviderName {
 853    fn from(value: String) -> Self {
 854        Self(SharedString::from(value))
 855    }
 856}
 857
 858impl From<Arc<str>> for LanguageModelProviderId {
 859    fn from(value: Arc<str>) -> Self {
 860        Self(SharedString::from(value))
 861    }
 862}
 863
 864impl From<Arc<str>> for LanguageModelProviderName {
 865    fn from(value: Arc<str>) -> Self {
 866        Self(SharedString::from(value))
 867    }
 868}
 869
#[cfg(test)]
mod tests {
    use super::*;

    // Maps the generic `upstream_http_error` cloud code according to the
    // `upstream_status` field inside the JSON payload:
    // 503 -> ServerOverloaded, 500 -> ApiInternalServerError.
    // NOTE(review): the 503 case here duplicates the first case of
    // `test_upstream_http_error_connection_timeout` below — consider
    // consolidating.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        // A 500 upstream status should surface as an internal server error,
        // preserving the upstream-provided message verbatim.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // The status can also arrive encoded in the code itself
    // (`upstream_http_503`) with a plain-text message rather than JSON.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // A connection-timeout message must be classified by the embedded
    // `upstream_status` (503 vs 500), not by the message text.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        // Same timeout message but upstream_status 500: expect an internal
        // server error carrying the full message.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // `thought_signature`, when present, is included in the serialized form.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Older payloads without a `thought_signature` field must still
    // deserialize, defaulting the signature to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Serialize -> deserialize round trip preserves a present signature.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Serialize -> deserialize round trip preserves an absent signature.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}