1mod api_key;
2mod model;
3mod provider;
4mod rate_limiter;
5mod registry;
6mod request;
7mod role;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anyhow::{Result, anyhow};
14use cloud_llm_client::CompletionRequestStatus;
15use futures::FutureExt;
16use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
17use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
18use http_client::{StatusCode, http};
19use icons::IconName;
20use parking_lot::Mutex;
21use serde::{Deserialize, Serialize};
22use std::ops::{Add, Sub};
23use std::str::FromStr;
24use std::sync::Arc;
25use std::time::Duration;
26use std::{fmt, io};
27use thiserror::Error;
28use util::serde::is_default;
29
30pub use crate::api_key::{ApiKey, ApiKeyState};
31pub use crate::model::*;
32pub use crate::rate_limiter::*;
33pub use crate::registry::*;
34pub use crate::request::*;
35pub use crate::role::*;
36pub use crate::tool_schema::LanguageModelToolSchemaFormat;
37pub use env_var::{EnvVar, env_var};
38pub use provider::*;
39
/// Initializes this crate's global state by delegating to `registry::init`.
pub fn init(cx: &mut App) {
    registry::init(cx);
}
43
/// Prompt-caching parameters for models that support caching.
#[derive(Clone, Debug)]
pub struct LanguageModelCacheConfiguration {
    /// Maximum number of cache anchors a request may use.
    pub max_cache_anchors: usize,
    /// Whether to speculatively mark content for caching.
    /// NOTE(review): semantics inferred from the field name — confirm against
    /// the providers that construct this value.
    pub should_speculate: bool,
    /// Minimum total token count before caching kicks in.
    pub min_total_token: u64,
}
50
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue; `position` is its place in line.
    Queued {
        position: usize,
    },
    /// The provider has begun servicing the request.
    Started,
    /// Generation finished, with the reason the model stopped.
    Stop(StopReason),
    /// A chunk of assistant text.
    Text(String),
    /// A chunk of the model's thinking, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider returned only in redacted (opaque) form.
    RedactedThinking {
        data: String,
    },
    /// A (possibly still-streaming) tool call.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Marks the start of a new assistant message.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counters for the request.
    UsageUpdate(TokenUsage),
}
80
81impl LanguageModelCompletionEvent {
82 pub fn from_completion_request_status(
83 status: CompletionRequestStatus,
84 upstream_provider: LanguageModelProviderName,
85 ) -> Result<Option<Self>, LanguageModelCompletionError> {
86 match status {
87 CompletionRequestStatus::Queued { position } => {
88 Ok(Some(LanguageModelCompletionEvent::Queued { position }))
89 }
90 CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
91 CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
92 CompletionRequestStatus::Failed {
93 code,
94 message,
95 request_id: _,
96 retry_after,
97 } => Err(LanguageModelCompletionError::from_cloud_failure(
98 upstream_provider,
99 code,
100 message,
101 retry_after.map(Duration::from_secs_f64),
102 )),
103 }
104 }
105}
106
/// An error produced while requesting or streaming a completion.
///
/// Retryable conditions carry an optional `retry_after` hint forwarded from
/// the provider.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded from an upstream provider; `message` is surfaced
    /// verbatim (NOTE(review): assumed to already be user-presentable).
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not covered by a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
197
impl LanguageModelCompletionError {
    /// Tries to parse `message` as the JSON payload the cloud proxy sends for
    /// `upstream_http_error` failures, extracting the upstream HTTP status and
    /// the inner human-readable message.
    ///
    /// Returns `None` when the payload is not JSON or lacks a valid
    /// `upstream_status` field.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        // Fall back to the raw payload when there is no inner "message" field.
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Converts a cloud completion failure (`code` + `message`) into a typed
    /// error, attributing upstream failures to `upstream_provider`.
    ///
    /// Precedence:
    /// 1. A recognizable "prompt too long" message wins regardless of `code`.
    /// 2. A literal `"upstream_http_error"` code with a JSON payload maps via
    ///    the embedded upstream status.
    /// 3. `"upstream_http_<status>"` codes map via that status.
    /// 4. `"http_<status>"` codes are attributed to Zed's own cloud provider.
    /// 5. Anything else becomes an opaque `Other` error.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Unparseable payload: surface the raw code and message.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // A plain `http_<status>` code means the failure happened in Zed's
            // cloud layer itself, not in the upstream provider.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from `provider` onto the most specific error
    /// variant available, falling back to `HttpResponseError`.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard status some providers (notably Anthropic)
            // use to signal that their servers are overloaded.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
289
/// Why the model stopped generating, serialized in `snake_case` on the wire.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a completion.
    Refusal,
}
298
/// Token counts reported by a provider for a request.
///
/// Every field defaults to zero and is omitted from serialization while zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens in the prompt.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
310
311impl TokenUsage {
312 pub fn total_tokens(&self) -> u64 {
313 self.input_tokens
314 + self.output_tokens
315 + self.cache_read_input_tokens
316 + self.cache_creation_input_tokens
317 }
318}
319
320impl Add<TokenUsage> for TokenUsage {
321 type Output = Self;
322
323 fn add(self, other: Self) -> Self {
324 Self {
325 input_tokens: self.input_tokens + other.input_tokens,
326 output_tokens: self.output_tokens + other.output_tokens,
327 cache_creation_input_tokens: self.cache_creation_input_tokens
328 + other.cache_creation_input_tokens,
329 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
330 }
331 }
332}
333
334impl Sub<TokenUsage> for TokenUsage {
335 type Output = Self;
336
337 fn sub(self, other: Self) -> Self {
338 Self {
339 input_tokens: self.input_tokens - other.input_tokens,
340 output_tokens: self.output_tokens - other.output_tokens,
341 cache_creation_input_tokens: self.cache_creation_input_tokens
342 - other.cache_creation_input_tokens,
343 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
344 }
345 }
346}
347
/// Provider-assigned identifier for a single tool invocation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
350
351impl fmt::Display for LanguageModelToolUseId {
352 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353 write!(f, "{}", self.0)
354 }
355}
356
357impl<T> From<T> for LanguageModelToolUseId
358where
359 T: Into<Arc<str>>,
360{
361 fn from(value: T) -> Self {
362 Self(value.into())
363 }
364}
365
/// A tool call requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Provider-assigned id for this invocation.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Raw (possibly still-streaming) JSON text of the input.
    pub raw_input: String,
    /// Parsed tool input.
    pub input: serde_json::Value,
    /// False while the input JSON is still being streamed.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
377
/// A text-only view over a completion stream.
pub struct LanguageModelTextStream {
    /// Id of the assistant message, when the provider reported one.
    pub message_id: Option<String>,
    /// Stream of text chunks (all non-text events filtered out).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
384
385impl Default for LanguageModelTextStream {
386 fn default() -> Self {
387 Self {
388 message_id: None,
389 stream: Box::pin(futures::stream::empty()),
390 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
391 }
392 }
393}
394
/// A thinking effort level a model can be asked to use.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    /// Human-readable label.
    pub name: SharedString,
    /// Provider-facing value for this level.
    // NOTE(review): assumed to be the wire value sent to the provider — confirm.
    pub value: SharedString,
    /// Whether this level is the model's default.
    pub is_default: bool,
}
401
/// A language model that can stream completions.
///
/// `stream_completion` is the single required streaming entry point; the
/// other `stream_completion_*` methods are adapters layered on top of it.
/// The `supports_*` capability probes default to the most conservative
/// answer and are overridden per provider.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// The provider actually serving requests when this model is proxied
    /// through another service; defaults to this model's own provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// See [`LanguageModel::upstream_provider_id`].
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Returns whether this model is the "latest", so we can highlight it in the UI.
    fn is_latest(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String;

    /// The API key to use for this model, if any.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Information about the cost of using this model, if available.
    fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
        None
    }

    /// Whether this model supports thinking.
    fn supports_thinking(&self) -> bool {
        false
    }

    /// Whether this model supports a "fast" mode.
    // NOTE(review): semantics inferred from the name — confirm against the
    // providers that override this.
    fn supports_fast_mode(&self) -> bool {
        false
    }

    /// Returns the list of supported effort levels that can be used when thinking.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        Vec::new()
    }

    /// Returns the default effort level to use when thinking.
    ///
    /// This is the first supported level flagged `is_default`, if any.
    fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
        self.supported_effort_levels()
            .into_iter()
            .find(|effort_level| effort_level.is_default)
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    /// Returns whether this model/provider reports accurate split input/output token counts.
    /// When true, the UI may show separate input/output token indicators.
    fn supports_split_token_display(&self) -> bool {
        false
    }

    /// The schema format this model expects for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Size of this model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Maximum number of output tokens, if the model imposes a limit.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for `request`.
    ///
    /// The outer `Result` covers failure to start the stream; items in the
    /// inner stream cover per-event failures.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Adapts [`LanguageModel::stream_completion`] into a plain text stream,
    /// dropping all non-text events and recording token usage as it arrives.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event so a leading `StartMessage` can populate
            // `message_id`; if the first event is already text, stash it so it
            // is not lost when the rest of the stream is adapted below.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend any stashed first chunk, then keep only `Text` events
            // and errors; `UsageUpdate` is folded into `last_token_usage` as a
            // side effect.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Adapts [`LanguageModel::stream_completion`] into a future resolving to
    /// the first complete tool use the model emits.
    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    /// Prompt-caching configuration, if this model supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcasts to the fake model used in tests; panics for real models.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
606
607impl std::fmt::Debug for dyn LanguageModel {
608 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
609 f.debug_struct("<dyn LanguageModel>")
610 .field("id", &self.id())
611 .field("name", &self.name())
612 .field("provider_id", &self.provider_id())
613 .field("provider_name", &self.provider_name())
614 .field("upstream_provider_name", &self.upstream_provider_name())
615 .field("upstream_provider_id", &self.upstream_provider_id())
616 .field("upstream_provider_id", &self.upstream_provider_id())
617 .field("supports_streaming_tools", &self.supports_streaming_tools())
618 .finish()
619 }
620}
621
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
632
/// Either a built-in icon name or a path to an external SVG.
///
/// Defaults to [`IconName::ZedAssistant`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
641
642impl Default for IconOrSvg {
643 fn default() -> Self {
644 Self::Icon(IconName::ZedAssistant)
645 }
646}
647
/// A source of language models, e.g. a vendor API integration.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the generic assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if one is available.
    // NOTE(review): "fast" semantics inferred from the name — confirm.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider, targeted at `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
670
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other, named agent.
    Other(SharedString),
}
677
/// Exposes a provider's observable state entity so callers can react to
/// changes.
pub trait LanguageModelProviderState: 'static {
    /// The entity type observed for changes.
    type ObservableEntity;

    /// The entity to observe, if this provider exposes one.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Subscribes `callback` to changes of the observable entity; returns
    /// `None` when there is nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
694
/// Stable identifier for a model within a provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
697
/// Human-readable model name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
700
/// Stable identifier for a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
703
/// Human-readable provider name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
706
/// How usage of a model is priced.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens, respectively.
    /// (The previous comment said "per 1,000", contradicting the `_per_1m`
    /// field names.)
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Flat cost per request.
    RequestCost { cost_per_request: f64 },
}
717
718impl LanguageModelCostInfo {
719 pub fn to_shared_string(&self) -> SharedString {
720 match self {
721 LanguageModelCostInfo::RequestCost { cost_per_request } => {
722 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
723 SharedString::from(cost_str)
724 }
725 LanguageModelCostInfo::TokenCost {
726 input_token_cost_per_1m,
727 output_token_cost_per_1m,
728 } => {
729 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
730 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
731 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
732 }
733 }
734 }
735
736 fn cost_value_to_string(cost: &f64) -> SharedString {
737 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
738 SharedString::from(format!("{:.0}", cost))
739 } else {
740 SharedString::from(format!("{:.2}", cost))
741 }
742 }
743}
744
impl LanguageModelProviderId {
    /// Creates a provider id from a static string without allocating.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
750
impl LanguageModelProviderName {
    /// Creates a provider name from a static string without allocating.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
756
757impl fmt::Display for LanguageModelProviderId {
758 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
759 write!(f, "{}", self.0)
760 }
761}
762
763impl fmt::Display for LanguageModelProviderName {
764 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
765 write!(f, "{}", self.0)
766 }
767}
768
769impl From<String> for LanguageModelId {
770 fn from(value: String) -> Self {
771 Self(SharedString::from(value))
772 }
773}
774
775impl From<String> for LanguageModelName {
776 fn from(value: String) -> Self {
777 Self(SharedString::from(value))
778 }
779}
780
781impl From<String> for LanguageModelProviderId {
782 fn from(value: String) -> Self {
783 Self(SharedString::from(value))
784 }
785}
786
787impl From<String> for LanguageModelProviderName {
788 fn from(value: String) -> Self {
789 Self(SharedString::from(value))
790 }
791}
792
793impl From<Arc<str>> for LanguageModelProviderId {
794 fn from(value: Arc<str>) -> Self {
795 Self(SharedString::from(value))
796 }
797}
798
799impl From<Arc<str>> for LanguageModelProviderName {
800 fn from(value: Arc<str>) -> Self {
801 Self(SharedString::from(value))
802 }
803}
804
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a cloud failure attributed to the "anthropic" provider with no
    /// retry-after hint — the shape every `from_cloud_failure` test uses.
    fn anthropic_cloud_failure(code: &str, message: &str) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            message.to_string(),
            None,
        )
    }

    /// Serializes `original`, deserializes it back, and checks the fields
    /// that must survive a round trip.
    fn assert_round_trips(original: &LanguageModelToolUse) {
        let serialized = serde_json::to_value(original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#,
        );
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#,
        );
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = anthropic_cloud_failure("upstream_http_503", "Service unavailable");

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#,
        );
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#,
        );
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        let payload = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(payload).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        assert_round_trips(&LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        });
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        assert_round_trips(&LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        });
    }
}