language_model.rs

  1mod api_key;
  2mod model;
  3pub mod provider;
  4mod rate_limiter;
  5mod registry;
  6mod request;
  7mod role;
  8pub mod tool_schema;
  9
 10#[cfg(any(test, feature = "test-support"))]
 11pub mod fake_provider;
 12
 13use anyhow::{Result, anyhow};
 14use client::Client;
 15use client::UserStore;
 16use cloud_llm_client::CompletionRequestStatus;
 17use futures::FutureExt;
 18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 19use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
 20use http_client::{StatusCode, http};
 21use icons::IconName;
 22use parking_lot::Mutex;
 23use provider::parse_prompt_too_long;
 24use serde::{Deserialize, Serialize};
 25pub use settings::LanguageModelCacheConfiguration;
 26use std::ops::{Add, Sub};
 27use std::str::FromStr;
 28use std::sync::Arc;
 29use std::time::Duration;
 30use std::{fmt, io};
 31use thiserror::Error;
 32use util::serde::is_default;
 33
 34pub use crate::api_key::{ApiKey, ApiKeyState};
 35pub use crate::model::*;
 36pub use crate::rate_limiter::*;
 37pub use crate::registry::*;
 38pub use crate::request::*;
 39pub use crate::role::*;
 40pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 41pub use zed_env_vars::{EnvVar, env_var};
 42
/// Initializes the language model crate.
///
/// Runs [`init_settings`] first, then registers the `RefreshLlmTokenListener`
/// with the given `client` and `user_store`.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}
 47
/// Initializes the language model registry by delegating to `registry::init`.
///
/// This is the part of [`init`] that needs neither a client nor a user store.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
 51
/// A completion event from a language model.
///
/// Events are yielded by `LanguageModel::stream_completion`; the text and tool
/// helpers on that trait filter this stream down to the variants they need.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is queued; produced from `CompletionRequestStatus::Queued`.
    Queued {
        // NOTE(review): presumably the position in the provider's queue — confirm.
        position: usize,
    },
    /// The request has started; produced from `CompletionRequestStatus::Started`.
    Started,
    /// The model stopped producing output for the given reason.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of thinking text, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content that is only available as opaque `data`.
    RedactedThinking {
        data: String,
    },
    /// A tool call; `is_input_complete` on the payload tells whether its input is final.
    ToolUse(LanguageModelToolUse),
    /// A tool call whose input could not be parsed as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Marks the start of a message and carries its identifier.
    StartMessage {
        message_id: String,
    },
    /// Reasoning details as an arbitrary JSON value.
    // NOTE(review): the schema looks provider-specific — confirm with the producers.
    ReasoningDetails(serde_json::Value),
    /// An updated token usage report.
    UsageUpdate(TokenUsage),
}
 81
 82impl LanguageModelCompletionEvent {
 83    pub fn from_completion_request_status(
 84        status: CompletionRequestStatus,
 85        upstream_provider: LanguageModelProviderName,
 86    ) -> Result<Option<Self>, LanguageModelCompletionError> {
 87        match status {
 88            CompletionRequestStatus::Queued { position } => {
 89                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
 90            }
 91            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
 92            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
 93            CompletionRequestStatus::Failed {
 94                code,
 95                message,
 96                request_id: _,
 97                retry_after,
 98            } => Err(LanguageModelCompletionError::from_cloud_failure(
 99                upstream_provider,
100                code,
101                message,
102                retry_after.map(Duration::from_secs_f64),
103            )),
104        }
105    }
106}
107
/// An error produced while requesting or streaming a completion.
///
/// Most HTTP-derived variants are constructed by
/// `LanguageModelCompletionError::from_http_status`; the status code each one
/// corresponds to is noted on the variant.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt does not fit the context window (HTTP 413, or a message
    /// recognized by `parse_prompt_too_long`). `tokens` is the parsed token
    /// count when one could be extracted.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is available for `provider`.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// HTTP 429 from the provider.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 503 or 529 from the provider.
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 500 from the provider.
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error reported by an upstream provider, carrying its status code.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Any HTTP status that has no dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// HTTP 400 from the provider.
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 401 from the provider.
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 403 from the provider.
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 404 from the provider.
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// An I/O error occurred while reading the response.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// The request could not be serialized to JSON.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// The HTTP request could not be built.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// The response could not be deserialized from JSON.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    /// The response stream ended before it was expected to.
    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    /// Catch-all for errors without a dedicated variant; displayed transparently.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
198
199impl LanguageModelCompletionError {
200    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
201        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
202        let upstream_status = error_json
203            .get("upstream_status")
204            .and_then(|v| v.as_u64())
205            .and_then(|status| u16::try_from(status).ok())
206            .and_then(|status| StatusCode::from_u16(status).ok())?;
207        let inner_message = error_json
208            .get("message")
209            .and_then(|v| v.as_str())
210            .unwrap_or(message)
211            .to_string();
212        Some((upstream_status, inner_message))
213    }
214
215    pub fn from_cloud_failure(
216        upstream_provider: LanguageModelProviderName,
217        code: String,
218        message: String,
219        retry_after: Option<Duration>,
220    ) -> Self {
221        if let Some(tokens) = parse_prompt_too_long(&message) {
222            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
223            // to be reported. This is a temporary workaround to handle this in the case where the
224            // token limit has been exceeded.
225            Self::PromptTooLarge {
226                tokens: Some(tokens),
227            }
228        } else if code == "upstream_http_error" {
229            if let Some((upstream_status, inner_message)) =
230                Self::parse_upstream_error_json(&message)
231            {
232                return Self::from_http_status(
233                    upstream_provider,
234                    upstream_status,
235                    inner_message,
236                    retry_after,
237                );
238            }
239            anyhow!("completion request failed, code: {code}, message: {message}").into()
240        } else if let Some(status_code) = code
241            .strip_prefix("upstream_http_")
242            .and_then(|code| StatusCode::from_str(code).ok())
243        {
244            Self::from_http_status(upstream_provider, status_code, message, retry_after)
245        } else if let Some(status_code) = code
246            .strip_prefix("http_")
247            .and_then(|code| StatusCode::from_str(code).ok())
248        {
249            Self::from_http_status(
250                provider::ZED_CLOUD_PROVIDER_NAME,
251                status_code,
252                message,
253                retry_after,
254            )
255        } else {
256            anyhow!("completion request failed, code: {code}, message: {message}").into()
257        }
258    }
259
260    pub fn from_http_status(
261        provider: LanguageModelProviderName,
262        status_code: StatusCode,
263        message: String,
264        retry_after: Option<Duration>,
265    ) -> Self {
266        match status_code {
267            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
268            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
269            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
270            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
271            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
272                tokens: parse_prompt_too_long(&message),
273            },
274            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
275                provider,
276                retry_after,
277            },
278            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
279            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
280                provider,
281                retry_after,
282            },
283            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
284                provider,
285                retry_after,
286            },
287            _ => Self::HttpResponseError {
288                provider,
289                status_code,
290                message,
291            },
292        }
293    }
294}
295
/// Why a model stopped producing output.
///
/// Serialized in `snake_case` (for example `end_turn`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    // NOTE(review): the meanings below are inferred from the variant names — confirm.
    /// The model finished its turn.
    EndTurn,
    /// The output token limit was reached.
    MaxTokens,
    /// The model stopped in order to use a tool.
    ToolUse,
    /// The model refused to respond.
    Refusal,
}
304
/// Token counters for a completion.
///
/// Every field defaults to zero when absent during deserialization and is
/// omitted from serialized output when it equals its default.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Number of input tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Number of output tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Number of input tokens counted as cache creation.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Number of input tokens counted as cache reads.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
316
317impl TokenUsage {
318    pub fn total_tokens(&self) -> u64 {
319        self.input_tokens
320            + self.output_tokens
321            + self.cache_read_input_tokens
322            + self.cache_creation_input_tokens
323    }
324}
325
326impl Add<TokenUsage> for TokenUsage {
327    type Output = Self;
328
329    fn add(self, other: Self) -> Self {
330        Self {
331            input_tokens: self.input_tokens + other.input_tokens,
332            output_tokens: self.output_tokens + other.output_tokens,
333            cache_creation_input_tokens: self.cache_creation_input_tokens
334                + other.cache_creation_input_tokens,
335            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
336        }
337    }
338}
339
340impl Sub<TokenUsage> for TokenUsage {
341    type Output = Self;
342
343    fn sub(self, other: Self) -> Self {
344        Self {
345            input_tokens: self.input_tokens - other.input_tokens,
346            output_tokens: self.output_tokens - other.output_tokens,
347            cache_creation_input_tokens: self.cache_creation_input_tokens
348                - other.cache_creation_input_tokens,
349            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
350        }
351    }
352}
353
/// Identifier of a single tool use, stored as a shared string.
///
/// Can be built from anything convertible into `Arc<str>`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
356
357impl fmt::Display for LanguageModelToolUseId {
358    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359        write!(f, "{}", self.0)
360    }
361}
362
363impl<T> From<T> for LanguageModelToolUseId
364where
365    T: Into<Arc<str>>,
366{
367    fn from(value: T) -> Self {
368        Self(value.into())
369    }
370}
371
/// A tool call emitted by a language model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being called.
    pub name: Arc<str>,
    /// The tool input as raw text.
    pub raw_input: String,
    /// The tool input as a JSON value.
    pub input: serde_json::Value,
    /// Whether the input is complete. `LanguageModel::stream_completion_tool`
    /// only returns tool uses for which this is `true`.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
383
/// A text-only view of a completion, as produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    /// Message id taken from a leading `StartMessage` event, if there was one.
    pub message_id: Option<String>,
    /// The text chunks of the completion, or the errors encountered while streaming.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    /// Overwritten with the latest `UsageUpdate` event as the stream is consumed.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
390
391impl Default for LanguageModelTextStream {
392    fn default() -> Self {
393        Self {
394            message_id: None,
395            stream: Box::pin(futures::stream::empty()),
396            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
397        }
398    }
399}
400
/// An effort level a model can use when thinking (see
/// `LanguageModel::supported_effort_levels`).
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    // NOTE(review): `name` looks like the display label and `value` the
    // identifier sent to the provider — confirm against the providers.
    pub name: SharedString,
    pub value: SharedString,
    /// Marks the level picked by `LanguageModel::default_effort_level`.
    pub is_default: bool,
}
407
408pub trait LanguageModel: Send + Sync {
409    fn id(&self) -> LanguageModelId;
410    fn name(&self) -> LanguageModelName;
411    fn provider_id(&self) -> LanguageModelProviderId;
412    fn provider_name(&self) -> LanguageModelProviderName;
413    fn upstream_provider_id(&self) -> LanguageModelProviderId {
414        self.provider_id()
415    }
416    fn upstream_provider_name(&self) -> LanguageModelProviderName {
417        self.provider_name()
418    }
419
420    /// Returns whether this model is the "latest", so we can highlight it in the UI.
421    fn is_latest(&self) -> bool {
422        false
423    }
424
425    fn telemetry_id(&self) -> String;
426
427    fn api_key(&self, _cx: &App) -> Option<String> {
428        None
429    }
430
431    /// Information about the cost of using this model, if available.
432    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
433        None
434    }
435
436    /// Whether this model supports thinking.
437    fn supports_thinking(&self) -> bool {
438        false
439    }
440
441    fn supports_fast_mode(&self) -> bool {
442        false
443    }
444
445    /// Returns the list of supported effort levels that can be used when thinking.
446    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
447        Vec::new()
448    }
449
450    /// Returns the default effort level to use when thinking.
451    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
452        self.supported_effort_levels()
453            .into_iter()
454            .find(|effort_level| effort_level.is_default)
455    }
456
457    /// Whether this model supports images
458    fn supports_images(&self) -> bool;
459
460    /// Whether this model supports tools.
461    fn supports_tools(&self) -> bool;
462
463    /// Whether this model supports choosing which tool to use.
464    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
465
466    /// Returns whether this model or provider supports streaming tool calls;
467    fn supports_streaming_tools(&self) -> bool {
468        false
469    }
470
471    /// Returns whether this model/provider reports accurate split input/output token counts.
472    /// When true, the UI may show separate input/output token indicators.
473    fn supports_split_token_display(&self) -> bool {
474        false
475    }
476
477    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
478        LanguageModelToolSchemaFormat::JsonSchema
479    }
480
481    fn max_token_count(&self) -> u64;
482    fn max_output_tokens(&self) -> Option<u64> {
483        None
484    }
485
486    fn count_tokens(
487        &self,
488        request: LanguageModelRequest,
489        cx: &App,
490    ) -> BoxFuture<'static, Result<u64>>;
491
492    fn stream_completion(
493        &self,
494        request: LanguageModelRequest,
495        cx: &AsyncApp,
496    ) -> BoxFuture<
497        'static,
498        Result<
499            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
500            LanguageModelCompletionError,
501        >,
502    >;
503
504    fn stream_completion_text(
505        &self,
506        request: LanguageModelRequest,
507        cx: &AsyncApp,
508    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
509        let future = self.stream_completion(request, cx);
510
511        async move {
512            let events = future.await?;
513            let mut events = events.fuse();
514            let mut message_id = None;
515            let mut first_item_text = None;
516            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
517
518            if let Some(first_event) = events.next().await {
519                match first_event {
520                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
521                        message_id = Some(id);
522                    }
523                    Ok(LanguageModelCompletionEvent::Text(text)) => {
524                        first_item_text = Some(text);
525                    }
526                    _ => (),
527                }
528            }
529
530            let stream = futures::stream::iter(first_item_text.map(Ok))
531                .chain(events.filter_map({
532                    let last_token_usage = last_token_usage.clone();
533                    move |result| {
534                        let last_token_usage = last_token_usage.clone();
535                        async move {
536                            match result {
537                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
538                                Ok(LanguageModelCompletionEvent::Started) => None,
539                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
540                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
541                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
542                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
543                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
544                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
545                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
546                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
547                                    ..
548                                }) => None,
549                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
550                                    *last_token_usage.lock() = token_usage;
551                                    None
552                                }
553                                Err(err) => Some(Err(err)),
554                            }
555                        }
556                    }
557                }))
558                .boxed();
559
560            Ok(LanguageModelTextStream {
561                message_id,
562                stream,
563                last_token_usage,
564            })
565        }
566        .boxed()
567    }
568
569    fn stream_completion_tool(
570        &self,
571        request: LanguageModelRequest,
572        cx: &AsyncApp,
573    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
574        let future = self.stream_completion(request, cx);
575
576        async move {
577            let events = future.await?;
578            let mut events = events.fuse();
579
580            // Iterate through events until we find a complete ToolUse
581            while let Some(event) = events.next().await {
582                match event {
583                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
584                        if tool_use.is_input_complete =>
585                    {
586                        return Ok(tool_use);
587                    }
588                    Err(err) => {
589                        return Err(err);
590                    }
591                    _ => {}
592                }
593            }
594
595            // Stream ended without a complete tool use
596            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
597                "Stream ended without receiving a complete tool use"
598            )))
599        }
600        .boxed()
601    }
602
603    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
604        None
605    }
606
607    #[cfg(any(test, feature = "test-support"))]
608    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
609        unimplemented!()
610    }
611}
612
613impl std::fmt::Debug for dyn LanguageModel {
614    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
615        f.debug_struct("<dyn LanguageModel>")
616            .field("id", &self.id())
617            .field("name", &self.name())
618            .field("provider_id", &self.provider_id())
619            .field("provider_name", &self.provider_name())
620            .field("upstream_provider_name", &self.upstream_provider_name())
621            .field("upstream_provider_id", &self.upstream_provider_id())
622            .field("upstream_provider_id", &self.upstream_provider_id())
623            .field("supports_streaming_tools", &self.supports_streaming_tools())
624            .finish()
625    }
626}
627
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The connection was refused.
    #[error("connection refused")]
    ConnectionRefused,
    /// No credentials were found.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure; displayed transparently and convertible from `anyhow::Error`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
638
/// Either a built-in icon name or a path to an external SVG.
///
/// Returned by `LanguageModelProvider::icon`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
647
648impl Default for IconOrSvg {
649    fn default() -> Self {
650        Self::Icon(IconName::ZedAssistant)
651    }
652}
653
/// A source of language models, together with its authentication and
/// configuration UI.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon for this provider; defaults to `IconOrSvg::default()`.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if it has one.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default fast model, if it has one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models offered by this provider.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models this provider recommends; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authentication; the task resolves to an [`AuthenticateError`] on failure.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration view for the given target agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Resets the provider's credentials.
    // NOTE(review): presumably clears any stored API key — confirm per provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
676
/// The agent a provider configuration view is built for (see
/// `LanguageModelProvider::configuration_view`).
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// The Zed agent; this is the default.
    #[default]
    ZedAgent,
    /// Another agent.
    // NOTE(review): the string is presumably the agent's display name — confirm.
    Other(SharedString),
}
683
684pub trait LanguageModelProviderState: 'static {
685    type ObservableEntity;
686
687    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
688
689    fn subscribe<T: 'static>(
690        &self,
691        cx: &mut gpui::Context<T>,
692        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
693    ) -> Option<gpui::Subscription> {
694        let entity = self.observable_entity()?;
695        Some(cx.observe(&entity, move |this, _, cx| {
696            callback(this, cx);
697        }))
698    }
699}
700
/// Identifier of a language model, wrapping a shared string.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
703
/// Name of a language model, wrapping a shared string.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
706
/// Identifier of a language model provider, wrapping a shared string.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
709
/// Name of a language model provider, wrapping a shared string.
///
/// Used verbatim in the messages of `LanguageModelCompletionError`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
712
/// Pricing information for a model, either per token or per request.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per one million input and output tokens, as the `_per_1m` field
    /// names indicate.
    // NOTE(review): this comment previously said "per 1,000" tokens, which
    // contradicts the field names — confirm the unit with the providers.
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
723
724impl LanguageModelCostInfo {
725    pub fn to_shared_string(&self) -> SharedString {
726        match self {
727            LanguageModelCostInfo::RequestCost { cost_per_request } => {
728                let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
729                SharedString::from(cost_str)
730            }
731            LanguageModelCostInfo::TokenCost {
732                input_token_cost_per_1m,
733                output_token_cost_per_1m,
734            } => {
735                let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
736                let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
737                SharedString::from(format!("{}$/{}$", input_cost, output_cost))
738            }
739        }
740    }
741
742    fn cost_value_to_string(cost: &f64) -> SharedString {
743        if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
744            SharedString::from(format!("{:.0}", cost))
745        } else {
746            SharedString::from(format!("{:.2}", cost))
747        }
748    }
749}
750
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
756
impl LanguageModelProviderName {
    /// Creates a provider name from a static string; usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
762
impl fmt::Display for LanguageModelProviderId {
    /// Writes the wrapped string as-is.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
768
impl fmt::Display for LanguageModelProviderName {
    /// Writes the wrapped string as-is.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
774
775impl From<String> for LanguageModelId {
776    fn from(value: String) -> Self {
777        Self(SharedString::from(value))
778    }
779}
780
781impl From<String> for LanguageModelName {
782    fn from(value: String) -> Self {
783        Self(SharedString::from(value))
784    }
785}
786
787impl From<String> for LanguageModelProviderId {
788    fn from(value: String) -> Self {
789        Self(SharedString::from(value))
790    }
791}
792
793impl From<String> for LanguageModelProviderName {
794    fn from(value: String) -> Self {
795        Self(SharedString::from(value))
796    }
797}
798
799impl From<Arc<str>> for LanguageModelProviderId {
800    fn from(value: Arc<str>) -> Self {
801        Self(SharedString::from(value))
802    }
803}
804
805impl From<Arc<str>> for LanguageModelProviderName {
806    fn from(value: Arc<str>) -> Self {
807        Self(SharedString::from(value))
808    }
809}
810
811#[cfg(test)]
812mod tests {
813    use super::*;
814
815    #[test]
816    fn test_from_cloud_failure_with_upstream_http_error() {
817        let error = LanguageModelCompletionError::from_cloud_failure(
818            String::from("anthropic").into(),
819            "upstream_http_error".to_string(),
820            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
821            None,
822        );
823
824        match error {
825            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
826                assert_eq!(provider.0, "anthropic");
827            }
828            _ => panic!(
829                "Expected ServerOverloaded error for 503 status, got: {:?}",
830                error
831            ),
832        }
833
834        let error = LanguageModelCompletionError::from_cloud_failure(
835            String::from("anthropic").into(),
836            "upstream_http_error".to_string(),
837            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
838            None,
839        );
840
841        match error {
842            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
843                assert_eq!(provider.0, "anthropic");
844                assert_eq!(message, "Internal server error");
845            }
846            _ => panic!(
847                "Expected ApiInternalServerError for 500 status, got: {:?}",
848                error
849            ),
850        }
851    }
852
853    #[test]
854    fn test_from_cloud_failure_with_standard_format() {
855        let error = LanguageModelCompletionError::from_cloud_failure(
856            String::from("anthropic").into(),
857            "upstream_http_503".to_string(),
858            "Service unavailable".to_string(),
859            None,
860        );
861
862        match error {
863            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
864                assert_eq!(provider.0, "anthropic");
865            }
866            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
867        }
868    }
869
870    #[test]
871    fn test_upstream_http_error_connection_timeout() {
872        let error = LanguageModelCompletionError::from_cloud_failure(
873            String::from("anthropic").into(),
874            "upstream_http_error".to_string(),
875            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
876            None,
877        );
878
879        match error {
880            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
881                assert_eq!(provider.0, "anthropic");
882            }
883            _ => panic!(
884                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
885                error
886            ),
887        }
888
889        let error = LanguageModelCompletionError::from_cloud_failure(
890            String::from("anthropic").into(),
891            "upstream_http_error".to_string(),
892            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
893            None,
894        );
895
896        match error {
897            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
898                assert_eq!(provider.0, "anthropic");
899                assert_eq!(
900                    message,
901                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
902                );
903            }
904            _ => panic!(
905                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
906                error
907            ),
908        }
909    }
910
911    #[test]
912    fn test_language_model_tool_use_serializes_with_signature() {
913        use serde_json::json;
914
915        let tool_use = LanguageModelToolUse {
916            id: LanguageModelToolUseId::from("test_id"),
917            name: "test_tool".into(),
918            raw_input: json!({"arg": "value"}).to_string(),
919            input: json!({"arg": "value"}),
920            is_input_complete: true,
921            thought_signature: Some("test_signature".to_string()),
922        };
923
924        let serialized = serde_json::to_value(&tool_use).unwrap();
925
926        assert_eq!(serialized["id"], "test_id");
927        assert_eq!(serialized["name"], "test_tool");
928        assert_eq!(serialized["thought_signature"], "test_signature");
929    }
930
931    #[test]
932    fn test_language_model_tool_use_deserializes_with_missing_signature() {
933        use serde_json::json;
934
935        let json = json!({
936            "id": "test_id",
937            "name": "test_tool",
938            "raw_input": "{\"arg\":\"value\"}",
939            "input": {"arg": "value"},
940            "is_input_complete": true
941        });
942
943        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();
944
945        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
946        assert_eq!(tool_use.name.as_ref(), "test_tool");
947        assert_eq!(tool_use.thought_signature, None);
948    }
949
950    #[test]
951    fn test_language_model_tool_use_round_trip_with_signature() {
952        use serde_json::json;
953
954        let original = LanguageModelToolUse {
955            id: LanguageModelToolUseId::from("round_trip_id"),
956            name: "round_trip_tool".into(),
957            raw_input: json!({"key": "value"}).to_string(),
958            input: json!({"key": "value"}),
959            is_input_complete: true,
960            thought_signature: Some("round_trip_sig".to_string()),
961        };
962
963        let serialized = serde_json::to_value(&original).unwrap();
964        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
965
966        assert_eq!(deserialized.id, original.id);
967        assert_eq!(deserialized.name, original.name);
968        assert_eq!(deserialized.thought_signature, original.thought_signature);
969    }
970
971    #[test]
972    fn test_language_model_tool_use_round_trip_without_signature() {
973        use serde_json::json;
974
975        let original = LanguageModelToolUse {
976            id: LanguageModelToolUseId::from("no_sig_id"),
977            name: "no_sig_tool".into(),
978            raw_input: json!({"arg": "value"}).to_string(),
979            input: json!({"arg": "value"}),
980            is_input_complete: true,
981            thought_signature: None,
982        };
983
984        let serialized = serde_json::to_value(&original).unwrap();
985        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();
986
987        assert_eq!(deserialized.id, original.id);
988        assert_eq!(deserialized.name, original.name);
989        assert_eq!(deserialized.thought_signature, None);
990    }
991}