language_model.rs

  1mod api_key;
  2mod model;
  3mod provider;
  4mod rate_limiter;
  5mod registry;
  6mod request;
  7mod role;
  8pub mod tool_schema;
  9
 10#[cfg(any(test, feature = "test-support"))]
 11pub mod fake_provider;
 12
 13use anyhow::{Result, anyhow};
 14use client::Client;
 15use client::UserStore;
 16use cloud_llm_client::CompletionRequestStatus;
 17use futures::FutureExt;
 18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 19use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
 20use http_client::{StatusCode, http};
 21use icons::IconName;
 22use parking_lot::Mutex;
 23use serde::{Deserialize, Serialize};
 24use std::ops::{Add, Sub};
 25use std::str::FromStr;
 26use std::sync::Arc;
 27use std::time::Duration;
 28use std::{fmt, io};
 29use thiserror::Error;
 30use util::serde::is_default;
 31
 32pub use crate::api_key::{ApiKey, ApiKeyState};
 33pub use crate::model::*;
 34pub use crate::rate_limiter::*;
 35pub use crate::registry::*;
 36pub use crate::request::*;
 37pub use crate::role::*;
 38pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 39pub use provider::*;
 40pub use zed_env_vars::{EnvVar, env_var};
 41
/// Initializes the language-model subsystem: registers settings and the
/// listener that refreshes the cloud LLM token for the given client/user.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}
 46
/// Registers the language-model registry's settings. Safe to call without
/// a client, e.g. for settings-only initialization.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
 50
/// Prompt-caching parameters for a model, used when deciding where to place
/// cache anchors in a request.
#[derive(Clone, Debug)]
pub struct LanguageModelCacheConfiguration {
    // Maximum number of cache anchors the model allows in one request.
    pub max_cache_anchors: usize,
    // Whether to speculatively add cache anchors before knowing they pay off.
    pub should_speculate: bool,
    // Minimum total token count below which caching is not worthwhile.
    pub min_total_token: u64,
}
 57
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is queued at the given position.
    Queued {
        position: usize,
    },
    /// The request has started being processed.
    Started,
    /// Generation stopped, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking, with an optional signature that some
    /// models require to be echoed back in conversation history.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted thinking data from the model.
    RedactedThinking {
        data: String,
    },
    /// A (possibly still-streaming) tool call requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The model emitted tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message began; carries the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counts for the request so far.
    UsageUpdate(TokenUsage),
}
 87
 88impl LanguageModelCompletionEvent {
 89    pub fn from_completion_request_status(
 90        status: CompletionRequestStatus,
 91        upstream_provider: LanguageModelProviderName,
 92    ) -> Result<Option<Self>, LanguageModelCompletionError> {
 93        match status {
 94            CompletionRequestStatus::Queued { position } => {
 95                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
 96            }
 97            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
 98            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
 99            CompletionRequestStatus::Failed {
100                code,
101                message,
102                request_id: _,
103                retry_after,
104            } => Err(LanguageModelCompletionError::from_cloud_failure(
105                upstream_provider,
106                code,
107                message,
108                retry_after.map(Duration::from_secs_f64),
109            )),
110        }
111    }
112}
113
/// An error produced while running a language model completion request.
///
/// Variants are grouped roughly by origin: server-side failures, client-side
/// request problems, and transport/serialization failures.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the
    /// offending count when it could be parsed from the provider's message.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded verbatim from an upstream provider.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not mapped to a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
204
205impl LanguageModelCompletionError {
206    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
207        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
208        let upstream_status = error_json
209            .get("upstream_status")
210            .and_then(|v| v.as_u64())
211            .and_then(|status| u16::try_from(status).ok())
212            .and_then(|status| StatusCode::from_u16(status).ok())?;
213        let inner_message = error_json
214            .get("message")
215            .and_then(|v| v.as_str())
216            .unwrap_or(message)
217            .to_string();
218        Some((upstream_status, inner_message))
219    }
220
221    pub fn from_cloud_failure(
222        upstream_provider: LanguageModelProviderName,
223        code: String,
224        message: String,
225        retry_after: Option<Duration>,
226    ) -> Self {
227        if let Some(tokens) = parse_prompt_too_long(&message) {
228            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
229            // to be reported. This is a temporary workaround to handle this in the case where the
230            // token limit has been exceeded.
231            Self::PromptTooLarge {
232                tokens: Some(tokens),
233            }
234        } else if code == "upstream_http_error" {
235            if let Some((upstream_status, inner_message)) =
236                Self::parse_upstream_error_json(&message)
237            {
238                return Self::from_http_status(
239                    upstream_provider,
240                    upstream_status,
241                    inner_message,
242                    retry_after,
243                );
244            }
245            anyhow!("completion request failed, code: {code}, message: {message}").into()
246        } else if let Some(status_code) = code
247            .strip_prefix("upstream_http_")
248            .and_then(|code| StatusCode::from_str(code).ok())
249        {
250            Self::from_http_status(upstream_provider, status_code, message, retry_after)
251        } else if let Some(status_code) = code
252            .strip_prefix("http_")
253            .and_then(|code| StatusCode::from_str(code).ok())
254        {
255            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
256        } else {
257            anyhow!("completion request failed, code: {code}, message: {message}").into()
258        }
259    }
260
261    pub fn from_http_status(
262        provider: LanguageModelProviderName,
263        status_code: StatusCode,
264        message: String,
265        retry_after: Option<Duration>,
266    ) -> Self {
267        match status_code {
268            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
269            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
270            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
271            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
272            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
273                tokens: parse_prompt_too_long(&message),
274            },
275            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
276                provider,
277                retry_after,
278            },
279            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
280            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
281                provider,
282                retry_after,
283            },
284            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
285                provider,
286                retry_after,
287            },
288            _ => Self::HttpResponseError {
289                provider,
290                status_code,
291                message,
292            },
293        }
294    }
295}
296
/// Why a model stopped generating, as reported in a `Stop` event.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
305
/// Token counts for a completion request; fields default to zero and are
/// omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Tokens consumed by the prompt.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Tokens produced in the response.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
317
318impl TokenUsage {
319    pub fn total_tokens(&self) -> u64 {
320        self.input_tokens
321            + self.output_tokens
322            + self.cache_read_input_tokens
323            + self.cache_creation_input_tokens
324    }
325}
326
327impl Add<TokenUsage> for TokenUsage {
328    type Output = Self;
329
330    fn add(self, other: Self) -> Self {
331        Self {
332            input_tokens: self.input_tokens + other.input_tokens,
333            output_tokens: self.output_tokens + other.output_tokens,
334            cache_creation_input_tokens: self.cache_creation_input_tokens
335                + other.cache_creation_input_tokens,
336            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
337        }
338    }
339}
340
341impl Sub<TokenUsage> for TokenUsage {
342    type Output = Self;
343
344    fn sub(self, other: Self) -> Self {
345        Self {
346            input_tokens: self.input_tokens - other.input_tokens,
347            output_tokens: self.output_tokens - other.output_tokens,
348            cache_creation_input_tokens: self.cache_creation_input_tokens
349                - other.cache_creation_input_tokens,
350            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
351        }
352    }
353}
354
/// Opaque identifier a model assigns to one of its tool calls.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
357
358impl fmt::Display for LanguageModelToolUseId {
359    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360        write!(f, "{}", self.0)
361    }
362}
363
// Allow constructing an id from anything convertible to `Arc<str>`
// (`&str`, `String`, ...).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
372
/// A tool call emitted by a model, possibly while its input is still
/// streaming in.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    // Model-assigned id for this tool call.
    pub id: LanguageModelToolUseId,
    // Name of the tool being invoked.
    pub name: Arc<str>,
    // The raw input text as received from the model.
    pub raw_input: String,
    // The input parsed as JSON (may be partial until `is_input_complete`).
    pub input: serde_json::Value,
    // True once the model has finished streaming this call's input.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
384
/// A text-only view over a completion: just the streamed text chunks, plus
/// the final token usage once the stream ends.
pub struct LanguageModelTextStream {
    // Provider-assigned message id, when one was reported.
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
391
392impl Default for LanguageModelTextStream {
393    fn default() -> Self {
394        Self {
395            message_id: None,
396            stream: Box::pin(futures::stream::empty()),
397            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
398        }
399    }
400}
401
/// One selectable "effort" setting for a model's thinking.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    // Display name shown in the UI.
    pub name: SharedString,
    // Value sent to the provider for this effort level.
    pub value: SharedString,
    // Whether this is the model's default effort level.
    pub is_default: bool,
}
408
409pub trait LanguageModel: Send + Sync {
410    fn id(&self) -> LanguageModelId;
411    fn name(&self) -> LanguageModelName;
412    fn provider_id(&self) -> LanguageModelProviderId;
413    fn provider_name(&self) -> LanguageModelProviderName;
414    fn upstream_provider_id(&self) -> LanguageModelProviderId {
415        self.provider_id()
416    }
417    fn upstream_provider_name(&self) -> LanguageModelProviderName {
418        self.provider_name()
419    }
420
421    /// Returns whether this model is the "latest", so we can highlight it in the UI.
422    fn is_latest(&self) -> bool {
423        false
424    }
425
426    fn telemetry_id(&self) -> String;
427
428    fn api_key(&self, _cx: &App) -> Option<String> {
429        None
430    }
431
432    /// Information about the cost of using this model, if available.
433    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
434        None
435    }
436
437    /// Whether this model supports thinking.
438    fn supports_thinking(&self) -> bool {
439        false
440    }
441
442    fn supports_fast_mode(&self) -> bool {
443        false
444    }
445
446    /// Returns the list of supported effort levels that can be used when thinking.
447    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
448        Vec::new()
449    }
450
451    /// Returns the default effort level to use when thinking.
452    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
453        self.supported_effort_levels()
454            .into_iter()
455            .find(|effort_level| effort_level.is_default)
456    }
457
458    /// Whether this model supports images
459    fn supports_images(&self) -> bool;
460
461    /// Whether this model supports tools.
462    fn supports_tools(&self) -> bool;
463
464    /// Whether this model supports choosing which tool to use.
465    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
466
467    /// Returns whether this model or provider supports streaming tool calls;
468    fn supports_streaming_tools(&self) -> bool {
469        false
470    }
471
472    /// Returns whether this model/provider reports accurate split input/output token counts.
473    /// When true, the UI may show separate input/output token indicators.
474    fn supports_split_token_display(&self) -> bool {
475        false
476    }
477
478    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
479        LanguageModelToolSchemaFormat::JsonSchema
480    }
481
482    fn max_token_count(&self) -> u64;
483    fn max_output_tokens(&self) -> Option<u64> {
484        None
485    }
486
487    fn count_tokens(
488        &self,
489        request: LanguageModelRequest,
490        cx: &App,
491    ) -> BoxFuture<'static, Result<u64>>;
492
493    fn stream_completion(
494        &self,
495        request: LanguageModelRequest,
496        cx: &AsyncApp,
497    ) -> BoxFuture<
498        'static,
499        Result<
500            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
501            LanguageModelCompletionError,
502        >,
503    >;
504
505    fn stream_completion_text(
506        &self,
507        request: LanguageModelRequest,
508        cx: &AsyncApp,
509    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
510        let future = self.stream_completion(request, cx);
511
512        async move {
513            let events = future.await?;
514            let mut events = events.fuse();
515            let mut message_id = None;
516            let mut first_item_text = None;
517            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
518
519            if let Some(first_event) = events.next().await {
520                match first_event {
521                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
522                        message_id = Some(id);
523                    }
524                    Ok(LanguageModelCompletionEvent::Text(text)) => {
525                        first_item_text = Some(text);
526                    }
527                    _ => (),
528                }
529            }
530
531            let stream = futures::stream::iter(first_item_text.map(Ok))
532                .chain(events.filter_map({
533                    let last_token_usage = last_token_usage.clone();
534                    move |result| {
535                        let last_token_usage = last_token_usage.clone();
536                        async move {
537                            match result {
538                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
539                                Ok(LanguageModelCompletionEvent::Started) => None,
540                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
541                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
542                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
543                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
544                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
545                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
546                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
547                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
548                                    ..
549                                }) => None,
550                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
551                                    *last_token_usage.lock() = token_usage;
552                                    None
553                                }
554                                Err(err) => Some(Err(err)),
555                            }
556                        }
557                    }
558                }))
559                .boxed();
560
561            Ok(LanguageModelTextStream {
562                message_id,
563                stream,
564                last_token_usage,
565            })
566        }
567        .boxed()
568    }
569
570    fn stream_completion_tool(
571        &self,
572        request: LanguageModelRequest,
573        cx: &AsyncApp,
574    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
575        let future = self.stream_completion(request, cx);
576
577        async move {
578            let events = future.await?;
579            let mut events = events.fuse();
580
581            // Iterate through events until we find a complete ToolUse
582            while let Some(event) = events.next().await {
583                match event {
584                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
585                        if tool_use.is_input_complete =>
586                    {
587                        return Ok(tool_use);
588                    }
589                    Err(err) => {
590                        return Err(err);
591                    }
592                    _ => {}
593                }
594            }
595
596            // Stream ended without a complete tool use
597            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
598                "Stream ended without receiving a complete tool use"
599            )))
600        }
601        .boxed()
602    }
603
604    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
605        None
606    }
607
608    #[cfg(any(test, feature = "test-support"))]
609    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
610        unimplemented!()
611    }
612}
613
614impl std::fmt::Debug for dyn LanguageModel {
615    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
616        f.debug_struct("<dyn LanguageModel>")
617            .field("id", &self.id())
618            .field("name", &self.name())
619            .field("provider_id", &self.provider_id())
620            .field("provider_name", &self.provider_name())
621            .field("upstream_provider_name", &self.upstream_provider_name())
622            .field("upstream_provider_id", &self.upstream_provider_id())
623            .field("upstream_provider_id", &self.upstream_provider_id())
624            .field("supports_streaming_tools", &self.supports_streaming_tools())
625            .finish()
626    }
627}
628
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
639
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
648
impl Default for IconOrSvg {
    /// Defaults to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
654
/// A source of language models, responsible for authentication and for
/// listing the models it offers.
pub trait LanguageModelProvider: 'static {
    /// Stable identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable name for this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider (defaults to the Zed assistant icon).
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The model to use by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The default model for fast/lightweight tasks, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently exposes.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models highlighted as recommended; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether valid credentials are currently available.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate against this provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
677
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other, named agent.
    Other(SharedString),
}
684
685pub trait LanguageModelProviderState: 'static {
686    type ObservableEntity;
687
688    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
689
690    fn subscribe<T: 'static>(
691        &self,
692        cx: &mut gpui::Context<T>,
693        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
694    ) -> Option<gpui::Subscription> {
695        let entity = self.observable_entity()?;
696        Some(cx.observe(&entity, move |this, _, cx| {
697            callback(this, cx);
698        }))
699    }
700}
701
/// Unique identifier of a model within its provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Unique identifier of a model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
713
/// How usage of a model is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1,000 input and output tokens
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
724
725impl LanguageModelCostInfo {
726    pub fn to_shared_string(&self) -> SharedString {
727        match self {
728            LanguageModelCostInfo::RequestCost { cost_per_request } => {
729                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
730                SharedString::from(cost_str)
731            }
732            LanguageModelCostInfo::TokenCost {
733                input_token_cost_per_1m,
734                output_token_cost_per_1m,
735            } => {
736                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
737                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
738                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
739            }
740        }
741    }
742
743    fn cost_value_to_string(cost: &f64) -> SharedString {
744        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
745            SharedString::from(format!("{:.0}", cost))
746        } else {
747            SharedString::from(format!("{:.2}", cost))
748        }
749    }
750}
751
impl LanguageModelProviderId {
    /// Creates a provider id from a static string without allocating.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
757
758impl LanguageModelProviderName {
759    pub const fn new(id: &'static str) -> Self {
760        Self(SharedString::new_static(id))
761    }
762}
763
// Display renders the wrapped string verbatim for both provider newtypes.
impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
775
// Owned-string conversions into the id/name newtypes.
impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

// `Arc<str>` conversions avoid re-allocating shared strings.
impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
811
812#[cfg(test)]
813mod tests {
814    use super::*;
815
    // Verifies that JSON-wrapped `upstream_http_error` payloads map their
    // embedded `upstream_status` (503 -> ServerOverloaded, 500 ->
    // ApiInternalServerError) and preserve the inner message.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }
853
    // Verifies that the `upstream_http_<status>` code format (no JSON
    // payload) maps 503 to ServerOverloaded for the upstream provider.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }
870
    // Verifies that identical connection-timeout messages are classified by
    // their `upstream_status` alone: 503 -> ServerOverloaded, 500 ->
    // ApiInternalServerError (with the full message preserved).
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
911
912    #[test]
913    fn test_language_model_tool_use_serializes_with_signature() {
914        use serde_json::json;
915
916        let tool_use = LanguageModelToolUse {
917            id: LanguageModelToolUseId::from("test_id"),
918            name: "test_tool".into(),
919            raw_input: json!({"arg": "value"}).to_string(),
920            input: json!({"arg": "value"}),
921            is_input_complete: true,
922            thought_signature: Some("test_signature".to_string()),
923        };
924
925        let serialized = serde_json::to_value(&tool_use).unwrap();
926
927        assert_eq!(serialized["id"], "test_id");
928        assert_eq!(serialized["name"], "test_tool");
929        assert_eq!(serialized["thought_signature"], "test_signature");
930    }
931
932    #[test]
933    fn test_language_model_tool_use_deserializes_with_missing_signature() {
934        use serde_json::json;
935
936        let json = json!({
937            "id": "test_id",
938            "name": "test_tool",
939            "raw_input": "{\"arg\":\"value\"}",
940            "input": {"arg": "value"},
941            "is_input_complete": true
942        });
943
944        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
945
946        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
947        assert_eq!(tool_use.name.as_ref(), "test_tool");
948        assert_eq!(tool_use.thought_signature, None);
949    }
950
951    #[test]
952    fn test_language_model_tool_use_round_trip_with_signature() {
953        use serde_json::json;
954
955        let original = LanguageModelToolUse {
956            id: LanguageModelToolUseId::from("round_trip_id"),
957            name: "round_trip_tool".into(),
958            raw_input: json!({"key": "value"}).to_string(),
959            input: json!({"key": "value"}),
960            is_input_complete: true,
961            thought_signature: Some("round_trip_sig".to_string()),
962        };
963
964        let serialized = serde_json::to_value(&original).unwrap();
965        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
966
967        assert_eq!(deserialized.id, original.id);
968        assert_eq!(deserialized.name, original.name);
969        assert_eq!(deserialized.thought_signature, original.thought_signature);
970    }
971
972    #[test]
973    fn test_language_model_tool_use_round_trip_without_signature() {
974        use serde_json::json;
975
976        let original = LanguageModelToolUse {
977            id: LanguageModelToolUseId::from("no_sig_id"),
978            name: "no_sig_tool".into(),
979            raw_input: json!({"arg": "value"}).to_string(),
980            input: json!({"arg": "value"}),
981            is_input_complete: true,
982            thought_signature: None,
983        };
984
985        let serialized = serde_json::to_value(&original).unwrap();
986        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
987
988        assert_eq!(deserialized.id, original.id);
989        assert_eq!(deserialized.name, original.name);
990        assert_eq!(deserialized.thought_signature, None);
991    }
992}