language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Queued {
        position: usize,
    },
    Started,
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    ReasoningDetails(serde_json::Value),
    UsageUpdate(TokenUsage),
}

impl LanguageModelCompletionEvent {
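    /// Maps a cloud `CompletionRequestStatus` onto a completion event, turning
    /// `Failed` statuses into a `LanguageModelCompletionError`.
    ///
    /// A minimal usage sketch (not a doctest):
    ///
    /// ```ignore
    /// let event = LanguageModelCompletionEvent::from_completion_request_status(
    ///     CompletionRequestStatus::Queued { position: 3 },
    ///     ANTHROPIC_PROVIDER_NAME,
    /// )?;
    /// assert_eq!(event, LanguageModelCompletionEvent::Queued { position: 3 });
    /// ```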
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Self, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(LanguageModelCompletionEvent::Queued { position })
            }
            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
            CompletionRequestStatus::UsageUpdated { amount, limit } => {
                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
            }
            CompletionRequestStatus::ToolUseLimitReached => {
                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
            }
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
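    /// Extracts the upstream HTTP status and inner message from a cloud error
    /// payload of the shape `{"upstream_status": 503, "message": "..."}`.
    /// Returns `None` if the message is not JSON or lacks a valid status code.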
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

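    /// Builds a completion error from a cloud failure code/message pair.
    /// Recognizes `upstream_http_error` JSON payloads as well as codes of the
    /// form `upstream_http_<status>` (upstream provider) and `http_<status>`
    /// (Zed cloud itself).
    ///
    /// A minimal sketch of the status-code path (not a doctest):
    ///
    /// ```ignore
    /// let error = LanguageModelCompletionError::from_cloud_failure(
    ///     ANTHROPIC_PROVIDER_NAME,
    ///     "upstream_http_503".to_string(),
    ///     "Service unavailable".to_string(),
    ///     None,
    /// );
    /// assert!(matches!(
    ///     error,
    ///     LanguageModelCompletionError::ServerOverloaded { .. }
    /// ));
    /// ```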
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

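    /// Maps an HTTP status code from a provider's API onto the matching error
    /// variant; unrecognized statuses fall through to `HttpResponseError`.
    /// Status 529 is Anthropic's non-standard "overloaded" code.
    ///
    /// A minimal sketch (not a doctest):
    ///
    /// ```ignore
    /// let error = LanguageModelCompletionError::from_http_status(
    ///     OPEN_AI_PROVIDER_NAME,
    ///     StatusCode::TOO_MANY_REQUESTS,
    ///     "rate limited".to_string(),
    ///     Some(Duration::from_secs(30)),
    /// );
    /// assert!(matches!(
    ///     error,
    ///     LanguageModelCompletionError::RateLimitExceeded { .. }
    /// ));
    /// ```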
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
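    /// Total tokens across input, output, and both cache counters.
    ///
    /// A minimal sketch (not a doctest):
    ///
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 50,
    ///     cache_creation_input_tokens: 10,
    ///     cache_read_input_tokens: 40,
    /// };
    /// assert_eq!(usage.total_tokens(), 200);
    /// ```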
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (`None` if `supports_burn_mode` is `false`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

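    /// Adapts `stream_completion` into a plain text stream: non-text events are
    /// filtered out, the first `StartMessage` (if any) supplies `message_id`,
    /// and the final `UsageUpdate` is recorded in `last_token_usage`.
    ///
    /// A minimal consumption sketch (not a doctest; `model` and `request` are
    /// assumed to exist in the caller):
    ///
    /// ```ignore
    /// let mut text_stream = model.stream_completion_text(request, &cx).await?;
    /// while let Some(chunk) = text_stream.stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// let usage = *text_stream.last_token_usage.lock();
    /// ```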
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
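    /// Effective context size for the given completion mode: `Max` ("burn
    /// mode") uses the burn-mode limit when the model provides one and falls
    /// back to the normal limit otherwise.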
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

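    /// Observes the provider's observable entity (if any), invoking `callback`
    /// whenever it notifies. Returns `None` when the provider has no entity to
    /// observe.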
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(name: &'static str) -> Self {
        Self(SharedString::new_static(name))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}