language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Queued {
        position: usize,
    },
    Started,
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    ReasoningDetails(serde_json::Value),
    UsageUpdate(TokenUsage),
}

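// Illustrative sketch (assumed caller-side code, not part of this crate's API):
// a completion stream yields the events above; text accumulates from `Text`
// while `Stop` carries the reason the model finished. `events` here stands in
// for a stream obtained from `LanguageModel::stream_completion`.
//
//     let mut text = String::new();
//     while let Some(event) = events.next().await {
//         match event? {
//             LanguageModelCompletionEvent::Text(chunk) => text.push_str(&chunk),
//             LanguageModelCompletionEvent::Stop(_) => break,
//             _ => {}
//         }
//     }
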
impl LanguageModelCompletionEvent {
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Self, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(LanguageModelCompletionEvent::Queued { position })
            }
            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
            CompletionRequestStatus::UsageUpdated { amount, limit } => {
                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
            }
            CompletionRequestStatus::ToolUseLimitReached => {
                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
            }
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
 181    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

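    // The cloud error payload that `parse_upstream_error_json` expects looks
    // like this (taken from the tests at the bottom of this file):
    //
    //     {"code":"upstream_http_error","message":"...","upstream_status":503}
    //
    // `from_cloud_failure` below maps that `upstream_status` onto the same
    // variants that a direct HTTP response from the provider would produce.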
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard status code some providers (e.g. Anthropic) use to
            // signal that their servers are overloaded.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

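// Illustrative sketch (assumed caller-side code, not an API in this crate):
// honoring the `retry_after` hints that the variants above surface. `sleep` is
// hypothetical and stands in for whatever timer the caller's executor provides.
//
//     match err {
//         LanguageModelCompletionError::RateLimitExceeded { retry_after: Some(delay), .. }
//         | LanguageModelCompletionError::ServerOverloaded { retry_after: Some(delay), .. } => {
//             sleep(delay).await; // then retry the request
//         }
//         other => return Err(other.into()),
//     }
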
 316impl From<AnthropicError> for LanguageModelCompletionError {
 317    fn from(error: AnthropicError) -> Self {
 318        let provider = ANTHROPIC_PROVIDER_NAME;
 319        match error {
 320            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
 321            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
 322            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
 323            AnthropicError::DeserializeResponse(error) => {
 324                Self::DeserializeResponse { provider, error }
 325            }
 326            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
 327            AnthropicError::HttpResponseError {
 328                status_code,
 329                message,
 330            } => Self::HttpResponseError {
 331                provider,
 332                status_code,
 333                message,
 334            },
 335            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
 336                provider,
 337                retry_after: Some(retry_after),
 338            },
 339            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
 340                provider,
 341                retry_after,
 342            },
 343            AnthropicError::ApiError(api_error) => api_error.into(),
 344        }
 345    }
 346}
 347
 348impl From<anthropic::ApiError> for LanguageModelCompletionError {
 349    fn from(error: anthropic::ApiError) -> Self {
 350        use anthropic::ApiErrorCode::*;
 351        let provider = ANTHROPIC_PROVIDER_NAME;
 352        match error.code() {
 353            Some(code) => match code {
 354                InvalidRequestError => Self::BadRequestFormat {
 355                    provider,
 356                    message: error.message,
 357                },
 358                AuthenticationError => Self::AuthenticationError {
 359                    provider,
 360                    message: error.message,
 361                },
 362                PermissionError => Self::PermissionError {
 363                    provider,
 364                    message: error.message,
 365                },
 366                NotFoundError => Self::ApiEndpointNotFound { provider },
 367                RequestTooLarge => Self::PromptTooLarge {
 368                    tokens: parse_prompt_too_long(&error.message),
 369                },
 370                RateLimitError => Self::RateLimitExceeded {
 371                    provider,
 372                    retry_after: None,
 373                },
 374                ApiError => Self::ApiInternalServerError {
 375                    provider,
 376                    message: error.message,
 377                },
 378                OverloadedError => Self::ServerOverloaded {
 379                    provider,
 380                    retry_after: None,
 381                },
 382            },
 383            None => Self::Other(error.into()),
 384        }
 385    }
 386}
 387
 388impl From<open_ai::RequestError> for LanguageModelCompletionError {
 389    fn from(error: open_ai::RequestError) -> Self {
 390        match error {
 391            open_ai::RequestError::HttpResponseError {
 392                provider,
 393                status_code,
 394                body,
 395                headers,
 396            } => {
 397                let retry_after = headers
 398                    .get(http::header::RETRY_AFTER)
 399                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
 400                    .map(Duration::from_secs);
 401
 402                Self::from_http_status(provider.into(), status_code, body, retry_after)
 403            }
 404            open_ai::RequestError::Other(e) => Self::Other(e),
 405        }
 406    }
 407}
 408
 409impl From<OpenRouterError> for LanguageModelCompletionError {
 410    fn from(error: OpenRouterError) -> Self {
 411        let provider = LanguageModelProviderName::new("OpenRouter");
 412        match error {
 413            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
 414            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
 415            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
 416            OpenRouterError::DeserializeResponse(error) => {
 417                Self::DeserializeResponse { provider, error }
 418            }
 419            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
 420            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
 421                provider,
 422                retry_after: Some(retry_after),
 423            },
 424            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
 425                provider,
 426                retry_after,
 427            },
 428            OpenRouterError::ApiError(api_error) => api_error.into(),
 429        }
 430    }
 431}
 432
 433impl From<open_router::ApiError> for LanguageModelCompletionError {
 434    fn from(error: open_router::ApiError) -> Self {
 435        use open_router::ApiErrorCode::*;
 436        let provider = LanguageModelProviderName::new("OpenRouter");
 437        match error.code {
 438            InvalidRequestError => Self::BadRequestFormat {
 439                provider,
 440                message: error.message,
 441            },
 442            AuthenticationError => Self::AuthenticationError {
 443                provider,
 444                message: error.message,
 445            },
 446            PaymentRequiredError => Self::AuthenticationError {
 447                provider,
 448                message: format!("Payment required: {}", error.message),
 449            },
 450            PermissionError => Self::PermissionError {
 451                provider,
 452                message: error.message,
 453            },
 454            RequestTimedOut => Self::HttpResponseError {
 455                provider,
 456                status_code: StatusCode::REQUEST_TIMEOUT,
 457                message: error.message,
 458            },
 459            RateLimitError => Self::RateLimitExceeded {
 460                provider,
 461                retry_after: None,
 462            },
 463            ApiError => Self::ApiInternalServerError {
 464                provider,
 465                message: error.message,
 466            },
 467            OverloadedError => Self::ServerOverloaded {
 468                provider,
 469                retry_after: None,
 470            },
 471        }
 472    }
 473}
 474
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

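// Illustrative sketch (assumed caller-side code): because `TokenUsage` is plain
// `Copy` data with the `Add`/`Sub` impls above, per-request usage can be folded
// into a running total.
//
//     let total = usages
//         .iter()
//         .fold(TokenUsage::default(), |acc, usage| acc + *usage);
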
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

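    // Illustrative sketch (assumed caller-side code): draining the text stream
    // produced by `stream_completion_text` above and reading the final token
    // usage once it is exhausted.
    //
    //     let LanguageModelTextStream { mut stream, last_token_usage, .. } =
    //         model.stream_completion_text(request, cx).await?;
    //     let mut text = String::new();
    //     while let Some(chunk) = stream.next().await {
    //         text.push_str(&chunk?);
    //     }
    //     let usage = *last_token_usage.lock();
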
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

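// Illustrative sketch (assumed caller-side code): an observer that re-renders
// whenever a provider's observable entity changes. `provider` is hypothetical
// and stands in for any implementor of `LanguageModelProviderState`.
//
//     let _subscription = provider.subscribe(cx, |_this, cx| cx.notify());
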
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}