language_model.rs

  1mod api_key;
  2mod model;
  3mod provider;
  4mod rate_limiter;
  5mod registry;
  6mod request;
  7mod role;
  8pub mod tool_schema;
  9
 10#[cfg(any(test, feature = "test-support"))]
 11pub mod fake_provider;
 12
 13use anyhow::{Result, anyhow};
 14use cloud_llm_client::CompletionRequestStatus;
 15use futures::FutureExt;
 16use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 17use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 18use http_client::{StatusCode, http};
 19use icons::IconName;
 20use parking_lot::Mutex;
 21use serde::{Deserialize, Serialize};
 22use std::ops::{Add, Sub};
 23use std::str::FromStr;
 24use std::sync::Arc;
 25use std::time::Duration;
 26use std::{fmt, io};
 27use thiserror::Error;
 28use util::serde::is_default;
 29
 30pub use crate::api_key::{ApiKey, ApiKeyState};
 31pub use crate::model::*;
 32pub use crate::rate_limiter::*;
 33pub use crate::registry::*;
 34pub use crate::request::*;
 35pub use crate::role::*;
 36pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 37pub use env_var::{EnvVar, env_var};
 38pub use provider::*;
 39
 40pub fn init(cx: &mut App) {
 41    registry::init(cx);
 42}
 43
 44#[derive(Clone, Debug)]
 45pub struct LanguageModelCacheConfiguration {
 46    pub max_cache_anchors: usize,
 47    pub should_speculate: bool,
 48    pub min_total_token: u64,
 49}
 50
 51/// A completion event from a language model.
 52#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 53pub enum LanguageModelCompletionEvent {
 54    Queued {
 55        position: usize,
 56    },
 57    Started,
 58    Stop(StopReason),
 59    Text(String),
 60    Thinking {
 61        text: String,
 62        signature: Option<String>,
 63    },
 64    RedactedThinking {
 65        data: String,
 66    },
 67    ToolUse(LanguageModelToolUse),
 68    ToolUseJsonParseError {
 69        id: LanguageModelToolUseId,
 70        tool_name: Arc<str>,
 71        raw_input: Arc<str>,
 72        json_parse_error: String,
 73    },
 74    StartMessage {
 75        message_id: String,
 76    },
 77    ReasoningDetails(serde_json::Value),
 78    UsageUpdate(TokenUsage),
 79}
 80
 81impl LanguageModelCompletionEvent {
 82    pub fn from_completion_request_status(
 83        status: CompletionRequestStatus,
 84        upstream_provider: LanguageModelProviderName,
 85    ) -> Result<Option<Self>, LanguageModelCompletionError> {
 86        match status {
 87            CompletionRequestStatus::Queued { position } => {
 88                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
 89            }
 90            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
 91            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
 92            CompletionRequestStatus::Failed {
 93                code,
 94                message,
 95                request_id: _,
 96                retry_after,
 97            } => Err(LanguageModelCompletionError::from_cloud_failure(
 98                upstream_provider,
 99                code,
100                message,
101                retry_after.map(Duration::from_secs_f64),
102            )),
103        }
104    }
105}
106
107#[derive(Error, Debug)]
108pub enum LanguageModelCompletionError {
109    #[error("prompt too large for context window")]
110    PromptTooLarge { tokens: Option<u64> },
111    #[error("missing {provider} API key")]
112    NoApiKey { provider: LanguageModelProviderName },
113    #[error("{provider}'s API rate limit exceeded")]
114    RateLimitExceeded {
115        provider: LanguageModelProviderName,
116        retry_after: Option<Duration>,
117    },
118    #[error("{provider}'s API servers are overloaded right now")]
119    ServerOverloaded {
120        provider: LanguageModelProviderName,
121        retry_after: Option<Duration>,
122    },
123    #[error("{provider}'s API server reported an internal server error: {message}")]
124    ApiInternalServerError {
125        provider: LanguageModelProviderName,
126        message: String,
127    },
128    #[error("{message}")]
129    UpstreamProviderError {
130        message: String,
131        status: StatusCode,
132        retry_after: Option<Duration>,
133    },
134    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
135    HttpResponseError {
136        provider: LanguageModelProviderName,
137        status_code: StatusCode,
138        message: String,
139    },
140
141    // Client errors
142    #[error("invalid request format to {provider}'s API: {message}")]
143    BadRequestFormat {
144        provider: LanguageModelProviderName,
145        message: String,
146    },
147    #[error("authentication error with {provider}'s API: {message}")]
148    AuthenticationError {
149        provider: LanguageModelProviderName,
150        message: String,
151    },
152    #[error("Permission error with {provider}'s API: {message}")]
153    PermissionError {
154        provider: LanguageModelProviderName,
155        message: String,
156    },
157    #[error("language model provider API endpoint not found")]
158    ApiEndpointNotFound { provider: LanguageModelProviderName },
159    #[error("I/O error reading response from {provider}'s API")]
160    ApiReadResponseError {
161        provider: LanguageModelProviderName,
162        #[source]
163        error: io::Error,
164    },
165    #[error("error serializing request to {provider} API")]
166    SerializeRequest {
167        provider: LanguageModelProviderName,
168        #[source]
169        error: serde_json::Error,
170    },
171    #[error("error building request body to {provider} API")]
172    BuildRequestBody {
173        provider: LanguageModelProviderName,
174        #[source]
175        error: http::Error,
176    },
177    #[error("error sending HTTP request to {provider} API")]
178    HttpSend {
179        provider: LanguageModelProviderName,
180        #[source]
181        error: anyhow::Error,
182    },
183    #[error("error deserializing {provider} API response")]
184    DeserializeResponse {
185        provider: LanguageModelProviderName,
186        #[source]
187        error: serde_json::Error,
188    },
189
190    #[error("stream from {provider} ended unexpectedly")]
191    StreamEndedUnexpectedly { provider: LanguageModelProviderName },
192
193    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
194    #[error(transparent)]
195    Other(#[from] anyhow::Error),
196}
197
198impl LanguageModelCompletionError {
199    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
200        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
201        let upstream_status = error_json
202            .get("upstream_status")
203            .and_then(|v| v.as_u64())
204            .and_then(|status| u16::try_from(status).ok())
205            .and_then(|status| StatusCode::from_u16(status).ok())?;
206        let inner_message = error_json
207            .get("message")
208            .and_then(|v| v.as_str())
209            .unwrap_or(message)
210            .to_string();
211        Some((upstream_status, inner_message))
212    }
213
214    pub fn from_cloud_failure(
215        upstream_provider: LanguageModelProviderName,
216        code: String,
217        message: String,
218        retry_after: Option<Duration>,
219    ) -> Self {
220        if let Some(tokens) = parse_prompt_too_long(&message) {
221            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
222            // to be reported. This is a temporary workaround to handle this in the case where the
223            // token limit has been exceeded.
224            Self::PromptTooLarge {
225                tokens: Some(tokens),
226            }
227        } else if code == "upstream_http_error" {
228            if let Some((upstream_status, inner_message)) =
229                Self::parse_upstream_error_json(&message)
230            {
231                return Self::from_http_status(
232                    upstream_provider,
233                    upstream_status,
234                    inner_message,
235                    retry_after,
236                );
237            }
238            anyhow!("completion request failed, code: {code}, message: {message}").into()
239        } else if let Some(status_code) = code
240            .strip_prefix("upstream_http_")
241            .and_then(|code| StatusCode::from_str(code).ok())
242        {
243            Self::from_http_status(upstream_provider, status_code, message, retry_after)
244        } else if let Some(status_code) = code
245            .strip_prefix("http_")
246            .and_then(|code| StatusCode::from_str(code).ok())
247        {
248            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
249        } else {
250            anyhow!("completion request failed, code: {code}, message: {message}").into()
251        }
252    }
253
254    pub fn from_http_status(
255        provider: LanguageModelProviderName,
256        status_code: StatusCode,
257        message: String,
258        retry_after: Option<Duration>,
259    ) -> Self {
260        match status_code {
261            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
262            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
263            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
264            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
265            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
266                tokens: parse_prompt_too_long(&message),
267            },
268            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
269                provider,
270                retry_after,
271            },
272            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
273            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
274                provider,
275                retry_after,
276            },
277            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
278                provider,
279                retry_after,
280            },
281            _ => Self::HttpResponseError {
282                provider,
283                status_code,
284                message,
285            },
286        }
287    }
288}
289
290#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
291#[serde(rename_all = "snake_case")]
292pub enum StopReason {
293    EndTurn,
294    MaxTokens,
295    ToolUse,
296    Refusal,
297}
298
299#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
300pub struct TokenUsage {
301    #[serde(default, skip_serializing_if = "is_default")]
302    pub input_tokens: u64,
303    #[serde(default, skip_serializing_if = "is_default")]
304    pub output_tokens: u64,
305    #[serde(default, skip_serializing_if = "is_default")]
306    pub cache_creation_input_tokens: u64,
307    #[serde(default, skip_serializing_if = "is_default")]
308    pub cache_read_input_tokens: u64,
309}
310
311impl TokenUsage {
312    pub fn total_tokens(&self) -> u64 {
313        self.input_tokens
314            + self.output_tokens
315            + self.cache_read_input_tokens
316            + self.cache_creation_input_tokens
317    }
318}
319
320impl Add<TokenUsage> for TokenUsage {
321    type Output = Self;
322
323    fn add(self, other: Self) -> Self {
324        Self {
325            input_tokens: self.input_tokens + other.input_tokens,
326            output_tokens: self.output_tokens + other.output_tokens,
327            cache_creation_input_tokens: self.cache_creation_input_tokens
328                + other.cache_creation_input_tokens,
329            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
330        }
331    }
332}
333
334impl Sub<TokenUsage> for TokenUsage {
335    type Output = Self;
336
337    fn sub(self, other: Self) -> Self {
338        Self {
339            input_tokens: self.input_tokens - other.input_tokens,
340            output_tokens: self.output_tokens - other.output_tokens,
341            cache_creation_input_tokens: self.cache_creation_input_tokens
342                - other.cache_creation_input_tokens,
343            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
344        }
345    }
346}
347
348#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
349pub struct LanguageModelToolUseId(Arc<str>);
350
351impl fmt::Display for LanguageModelToolUseId {
352    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353        write!(f, "{}", self.0)
354    }
355}
356
357impl<T> From<T> for LanguageModelToolUseId
358where
359    T: Into<Arc<str>>,
360{
361    fn from(value: T) -> Self {
362        Self(value.into())
363    }
364}
365
366#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
367pub struct LanguageModelToolUse {
368    pub id: LanguageModelToolUseId,
369    pub name: Arc<str>,
370    pub raw_input: String,
371    pub input: serde_json::Value,
372    pub is_input_complete: bool,
373    /// Thought signature the model sent us. Some models require that this
374    /// signature be preserved and sent back in conversation history for validation.
375    pub thought_signature: Option<String>,
376}
377
378pub struct LanguageModelTextStream {
379    pub message_id: Option<String>,
380    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
381    // Has complete token usage after the stream has finished
382    pub last_token_usage: Arc<Mutex<TokenUsage>>,
383}
384
385impl Default for LanguageModelTextStream {
386    fn default() -> Self {
387        Self {
388            message_id: None,
389            stream: Box::pin(futures::stream::empty()),
390            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
391        }
392    }
393}
394
395#[derive(Debug, Clone)]
396pub struct LanguageModelEffortLevel {
397    pub name: SharedString,
398    pub value: SharedString,
399    pub is_default: bool,
400}
401
402pub trait LanguageModel: Send + Sync {
403    fn id(&self) -> LanguageModelId;
404    fn name(&self) -> LanguageModelName;
405    fn provider_id(&self) -> LanguageModelProviderId;
406    fn provider_name(&self) -> LanguageModelProviderName;
407    fn upstream_provider_id(&self) -> LanguageModelProviderId {
408        self.provider_id()
409    }
410    fn upstream_provider_name(&self) -> LanguageModelProviderName {
411        self.provider_name()
412    }
413
414    /// Returns whether this model is the "latest", so we can highlight it in the UI.
415    fn is_latest(&self) -> bool {
416        false
417    }
418
419    fn telemetry_id(&self) -> String;
420
421    fn api_key(&self, _cx: &App) -> Option<String> {
422        None
423    }
424
425    /// Information about the cost of using this model, if available.
426    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
427        None
428    }
429
430    /// Whether this model supports thinking.
431    fn supports_thinking(&self) -> bool {
432        false
433    }
434
435    fn supports_fast_mode(&self) -> bool {
436        false
437    }
438
439    /// Returns the list of supported effort levels that can be used when thinking.
440    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
441        Vec::new()
442    }
443
444    /// Returns the default effort level to use when thinking.
445    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
446        self.supported_effort_levels()
447            .into_iter()
448            .find(|effort_level| effort_level.is_default)
449    }
450
451    /// Whether this model supports images
452    fn supports_images(&self) -> bool;
453
454    /// Whether this model supports tools.
455    fn supports_tools(&self) -> bool;
456
457    /// Whether this model supports choosing which tool to use.
458    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
459
460    /// Returns whether this model or provider supports streaming tool calls;
461    fn supports_streaming_tools(&self) -> bool {
462        false
463    }
464
465    /// Returns whether this model/provider reports accurate split input/output token counts.
466    /// When true, the UI may show separate input/output token indicators.
467    fn supports_split_token_display(&self) -> bool {
468        false
469    }
470
471    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
472        LanguageModelToolSchemaFormat::JsonSchema
473    }
474
475    fn max_token_count(&self) -> u64;
476    fn max_output_tokens(&self) -> Option<u64> {
477        None
478    }
479
480    fn count_tokens(
481        &self,
482        request: LanguageModelRequest,
483        cx: &App,
484    ) -> BoxFuture<'static, Result<u64>>;
485
486    fn stream_completion(
487        &self,
488        request: LanguageModelRequest,
489        cx: &AsyncApp,
490    ) -> BoxFuture<
491        'static,
492        Result<
493            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
494            LanguageModelCompletionError,
495        >,
496    >;
497
498    fn stream_completion_text(
499        &self,
500        request: LanguageModelRequest,
501        cx: &AsyncApp,
502    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
503        let future = self.stream_completion(request, cx);
504
505        async move {
506            let events = future.await?;
507            let mut events = events.fuse();
508            let mut message_id = None;
509            let mut first_item_text = None;
510            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
511
512            if let Some(first_event) = events.next().await {
513                match first_event {
514                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
515                        message_id = Some(id);
516                    }
517                    Ok(LanguageModelCompletionEvent::Text(text)) => {
518                        first_item_text = Some(text);
519                    }
520                    _ => (),
521                }
522            }
523
524            let stream = futures::stream::iter(first_item_text.map(Ok))
525                .chain(events.filter_map({
526                    let last_token_usage = last_token_usage.clone();
527                    move |result| {
528                        let last_token_usage = last_token_usage.clone();
529                        async move {
530                            match result {
531                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
532                                Ok(LanguageModelCompletionEvent::Started) => None,
533                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
534                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
535                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
536                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
537                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
538                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
539                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
540                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
541                                    ..
542                                }) => None,
543                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
544                                    *last_token_usage.lock() = token_usage;
545                                    None
546                                }
547                                Err(err) => Some(Err(err)),
548                            }
549                        }
550                    }
551                }))
552                .boxed();
553
554            Ok(LanguageModelTextStream {
555                message_id,
556                stream,
557                last_token_usage,
558            })
559        }
560        .boxed()
561    }
562
563    fn stream_completion_tool(
564        &self,
565        request: LanguageModelRequest,
566        cx: &AsyncApp,
567    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
568        let future = self.stream_completion(request, cx);
569
570        async move {
571            let events = future.await?;
572            let mut events = events.fuse();
573
574            // Iterate through events until we find a complete ToolUse
575            while let Some(event) = events.next().await {
576                match event {
577                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
578                        if tool_use.is_input_complete =>
579                    {
580                        return Ok(tool_use);
581                    }
582                    Err(err) => {
583                        return Err(err);
584                    }
585                    _ => {}
586                }
587            }
588
589            // Stream ended without a complete tool use
590            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
591                "Stream ended without receiving a complete tool use"
592            )))
593        }
594        .boxed()
595    }
596
597    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
598        None
599    }
600
601    #[cfg(any(test, feature = "test-support"))]
602    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
603        unimplemented!()
604    }
605}
606
607impl std::fmt::Debug for dyn LanguageModel {
608    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
609        f.debug_struct("<dyn LanguageModel>")
610            .field("id", &self.id())
611            .field("name", &self.name())
612            .field("provider_id", &self.provider_id())
613            .field("provider_name", &self.provider_name())
614            .field("upstream_provider_name", &self.upstream_provider_name())
615            .field("upstream_provider_id", &self.upstream_provider_id())
616            .field("upstream_provider_id", &self.upstream_provider_id())
617            .field("supports_streaming_tools", &self.supports_streaming_tools())
618            .finish()
619    }
620}
621
622/// An error that occurred when trying to authenticate the language model provider.
623#[derive(Debug, Error)]
624pub enum AuthenticateError {
625    #[error("connection refused")]
626    ConnectionRefused,
627    #[error("credentials not found")]
628    CredentialsNotFound,
629    #[error(transparent)]
630    Other(#[from] anyhow::Error),
631}
632
633/// Either a built-in icon name or a path to an external SVG.
634#[derive(Debug, Clone, PartialEq, Eq)]
635pub enum IconOrSvg {
636    /// A built-in icon from Zed's icon set.
637    Icon(IconName),
638    /// Path to a custom SVG icon file.
639    Svg(SharedString),
640}
641
642impl Default for IconOrSvg {
643    fn default() -> Self {
644        Self::Icon(IconName::ZedAssistant)
645    }
646}
647
648pub trait LanguageModelProvider: 'static {
649    fn id(&self) -> LanguageModelProviderId;
650    fn name(&self) -> LanguageModelProviderName;
651    fn icon(&self) -> IconOrSvg {
652        IconOrSvg::default()
653    }
654    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
655    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
656    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
657    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
658        Vec::new()
659    }
660    fn is_authenticated(&self, cx: &App) -> bool;
661    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
662    fn configuration_view(
663        &self,
664        target_agent: ConfigurationViewTargetAgent,
665        window: &mut Window,
666        cx: &mut App,
667    ) -> AnyView;
668    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
669}
670
671#[derive(Default, Clone, PartialEq, Eq)]
672pub enum ConfigurationViewTargetAgent {
673    #[default]
674    ZedAgent,
675    Other(SharedString),
676}
677
678pub trait LanguageModelProviderState: 'static {
679    type ObservableEntity;
680
681    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
682
683    fn subscribe<T: 'static>(
684        &self,
685        cx: &mut gpui::Context<T>,
686        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
687    ) -> Option<gpui::Subscription> {
688        let entity = self.observable_entity()?;
689        Some(cx.observe(&entity, move |this, _, cx| {
690            callback(this, cx);
691        }))
692    }
693}
694
695#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
696pub struct LanguageModelId(pub SharedString);
697
698#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
699pub struct LanguageModelName(pub SharedString);
700
701#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
702pub struct LanguageModelProviderId(pub SharedString);
703
704#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
705pub struct LanguageModelProviderName(pub SharedString);
706
/// Pricing information for a model.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million (1,000,000) input and output tokens.
    TokenCost {
        /// Cost of 1 million input tokens.
        input_token_cost_per_1m: f64,
        /// Cost of 1 million output tokens.
        output_token_cost_per_1m: f64,
    },
    /// Cost per request.
    RequestCost { cost_per_request: f64 },
}
717
718impl LanguageModelCostInfo {
719    pub fn to_shared_string(&self) -> SharedString {
720        match self {
721            LanguageModelCostInfo::RequestCost { cost_per_request } => {
722                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
723                SharedString::from(cost_str)
724            }
725            LanguageModelCostInfo::TokenCost {
726                input_token_cost_per_1m,
727                output_token_cost_per_1m,
728            } => {
729                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
730                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
731                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
732            }
733        }
734    }
735
736    fn cost_value_to_string(cost: &f64) -> SharedString {
737        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
738            SharedString::from(format!("{:.0}", cost))
739        } else {
740            SharedString::from(format!("{:.2}", cost))
741        }
742    }
743}
744
745impl LanguageModelProviderId {
746    pub const fn new(id: &'static str) -> Self {
747        Self(SharedString::new_static(id))
748    }
749}
750
751impl LanguageModelProviderName {
752    pub const fn new(id: &'static str) -> Self {
753        Self(SharedString::new_static(id))
754    }
755}
756
757impl fmt::Display for LanguageModelProviderId {
758    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
759        write!(f, "{}", self.0)
760    }
761}
762
763impl fmt::Display for LanguageModelProviderName {
764    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
765        write!(f, "{}", self.0)
766    }
767}
768
769impl From<String> for LanguageModelId {
770    fn from(value: String) -> Self {
771        Self(SharedString::from(value))
772    }
773}
774
775impl From<String> for LanguageModelName {
776    fn from(value: String) -> Self {
777        Self(SharedString::from(value))
778    }
779}
780
781impl From<String> for LanguageModelProviderId {
782    fn from(value: String) -> Self {
783        Self(SharedString::from(value))
784    }
785}
786
787impl From<String> for LanguageModelProviderName {
788    fn from(value: String) -> Self {
789        Self(SharedString::from(value))
790    }
791}
792
793impl From<Arc<str>> for LanguageModelProviderId {
794    fn from(value: Arc<str>) -> Self {
795        Self(SharedString::from(value))
796    }
797}
798
799impl From<Arc<str>> for LanguageModelProviderName {
800    fn from(value: Arc<str>) -> Self {
801        Self(SharedString::from(value))
802    }
803}
804
805#[cfg(test)]
806mod tests {
807    use super::*;
808
809    #[test]
810    fn test_from_cloud_failure_with_upstream_http_error() {
811        let error = LanguageModelCompletionError::from_cloud_failure(
812            String::from("anthropic").into(),
813            "upstream_http_error".to_string(),
814            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
815            None,
816        );
817
818        match error {
819            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
820                assert_eq!(provider.0, "anthropic");
821            }
822            _ => panic!(
823                "Expected ServerOverloaded error for 503 status, got: {:?}",
824                error
825            ),
826        }
827
828        let error = LanguageModelCompletionError::from_cloud_failure(
829            String::from("anthropic").into(),
830            "upstream_http_error".to_string(),
831            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
832            None,
833        );
834
835        match error {
836            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
837                assert_eq!(provider.0, "anthropic");
838                assert_eq!(message, "Internal server error");
839            }
840            _ => panic!(
841                "Expected ApiInternalServerError for 500 status, got: {:?}",
842                error
843            ),
844        }
845    }
846
847    #[test]
848    fn test_from_cloud_failure_with_standard_format() {
849        let error = LanguageModelCompletionError::from_cloud_failure(
850            String::from("anthropic").into(),
851            "upstream_http_503".to_string(),
852            "Service unavailable".to_string(),
853            None,
854        );
855
856        match error {
857            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
858                assert_eq!(provider.0, "anthropic");
859            }
860            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
861        }
862    }
863
864    #[test]
865    fn test_upstream_http_error_connection_timeout() {
866        let error = LanguageModelCompletionError::from_cloud_failure(
867            String::from("anthropic").into(),
868            "upstream_http_error".to_string(),
869            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
870            None,
871        );
872
873        match error {
874            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
875                assert_eq!(provider.0, "anthropic");
876            }
877            _ => panic!(
878                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
879                error
880            ),
881        }
882
883        let error = LanguageModelCompletionError::from_cloud_failure(
884            String::from("anthropic").into(),
885            "upstream_http_error".to_string(),
886            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
887            None,
888        );
889
890        match error {
891            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
892                assert_eq!(provider.0, "anthropic");
893                assert_eq!(
894                    message,
895                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
896                );
897            }
898            _ => panic!(
899                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
900                error
901            ),
902        }
903    }
904
905    #[test]
906    fn test_language_model_tool_use_serializes_with_signature() {
907        use serde_json::json;
908
909        let tool_use = LanguageModelToolUse {
910            id: LanguageModelToolUseId::from("test_id"),
911            name: "test_tool".into(),
912            raw_input: json!({"arg": "value"}).to_string(),
913            input: json!({"arg": "value"}),
914            is_input_complete: true,
915            thought_signature: Some("test_signature".to_string()),
916        };
917
918        let serialized = serde_json::to_value(&tool_use).unwrap();
919
920        assert_eq!(serialized["id"], "test_id");
921        assert_eq!(serialized["name"], "test_tool");
922        assert_eq!(serialized["thought_signature"], "test_signature");
923    }
924
925    #[test]
926    fn test_language_model_tool_use_deserializes_with_missing_signature() {
927        use serde_json::json;
928
929        let json = json!({
930            "id": "test_id",
931            "name": "test_tool",
932            "raw_input": "{\"arg\":\"value\"}",
933            "input": {"arg": "value"},
934            "is_input_complete": true
935        });
936
937        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
938
939        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
940        assert_eq!(tool_use.name.as_ref(), "test_tool");
941        assert_eq!(tool_use.thought_signature, None);
942    }
943
944    #[test]
945    fn test_language_model_tool_use_round_trip_with_signature() {
946        use serde_json::json;
947
948        let original = LanguageModelToolUse {
949            id: LanguageModelToolUseId::from("round_trip_id"),
950            name: "round_trip_tool".into(),
951            raw_input: json!({"key": "value"}).to_string(),
952            input: json!({"key": "value"}),
953            is_input_complete: true,
954            thought_signature: Some("round_trip_sig".to_string()),
955        };
956
957        let serialized = serde_json::to_value(&original).unwrap();
958        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
959
960        assert_eq!(deserialized.id, original.id);
961        assert_eq!(deserialized.name, original.name);
962        assert_eq!(deserialized.thought_signature, original.thought_signature);
963    }
964
965    #[test]
966    fn test_language_model_tool_use_round_trip_without_signature() {
967        use serde_json::json;
968
969        let original = LanguageModelToolUse {
970            id: LanguageModelToolUseId::from("no_sig_id"),
971            name: "no_sig_tool".into(),
972            raw_input: json!({"arg": "value"}).to_string(),
973            input: json!({"arg": "value"}),
974            is_input_complete: true,
975            thought_signature: None,
976        };
977
978        let serialized = serde_json::to_value(&original).unwrap();
979        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
980
981        assert_eq!(deserialized.id, original.id);
982        assert_eq!(deserialized.name, original.name);
983        assert_eq!(deserialized.thought_signature, None);
984    }
985}