1mod api_key;
2mod model;
3pub mod provider;
4mod rate_limiter;
5mod registry;
6mod request;
7mod role;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anyhow::{Result, anyhow};
14use client::Client;
15use client::UserStore;
16use cloud_llm_client::CompletionRequestStatus;
17use futures::FutureExt;
18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
19use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
20use http_client::{StatusCode, http};
21use icons::IconName;
22use parking_lot::Mutex;
23use provider::parse_prompt_too_long;
24use serde::{Deserialize, Serialize};
25pub use settings::LanguageModelCacheConfiguration;
26use std::ops::{Add, Sub};
27use std::str::FromStr;
28use std::sync::Arc;
29use std::time::Duration;
30use std::{fmt, io};
31use thiserror::Error;
32use util::serde::is_default;
33
34pub use crate::api_key::{ApiKey, ApiKeyState};
35pub use crate::model::*;
36pub use crate::rate_limiter::*;
37pub use crate::registry::*;
38pub use crate::request::*;
39pub use crate::role::*;
40pub use crate::tool_schema::LanguageModelToolSchemaFormat;
41pub use zed_env_vars::{EnvVar, env_var};
42
/// Initializes the crate: registers settings and starts listening for
/// LLM token refreshes on the given client/user store.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}
47
/// Initializes the language model registry (safe to call without `init`
/// when token refresh is not needed, e.g. in tests).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
51
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue; `position` is its place in line.
    Queued {
        position: usize,
    },
    /// The provider has started processing the request.
    Started,
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking output.
    Thinking {
        text: String,
        /// Opaque signature some providers attach to thinking content.
        signature: Option<String>,
    },
    /// Thinking content returned by the provider in redacted form.
    RedactedThinking {
        data: String,
    },
    /// A (possibly still-streaming) tool call made by the model.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Marks the start of a new message and carries its provider-assigned ID.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counts for the request so far.
    UsageUpdate(TokenUsage),
}
81
impl LanguageModelCompletionEvent {
    /// Converts a cloud `CompletionRequestStatus` into a completion event.
    ///
    /// Returns `Ok(None)` for statuses that carry no event
    /// (`Unknown`/`StreamEnded`), and maps `Failed` statuses into a
    /// `LanguageModelCompletionError` attributed to `upstream_provider`.
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Option<Self>, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(Some(LanguageModelCompletionEvent::Queued { position }))
            }
            CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
            CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}
107
/// Errors that can occur while requesting or streaming a completion.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    // Server-side / capacity errors
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded from an upstream provider; `message` is shown as-is.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not mapped to a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
198
impl LanguageModelCompletionError {
    /// Extracts `(status, message)` from a JSON error payload of the form
    /// `{"upstream_status": <u16>, "message": "..."}`.
    ///
    /// Returns `None` if the payload is not JSON or carries no valid
    /// `upstream_status`. If the `message` field is missing, the whole
    /// payload string is used as the message.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Maps a cloud completion failure (`code` + `message`) to a typed error,
    /// attributing upstream failures to `upstream_provider`.
    ///
    /// Recognized shapes, in priority order:
    /// 1. a "prompt too long" message (regardless of code),
    /// 2. `upstream_http_error` with a JSON payload containing `upstream_status`,
    /// 3. `upstream_http_<status>` — an HTTP status from the upstream provider,
    /// 4. `http_<status>` — an HTTP status from Zed's cloud itself.
    /// Anything else falls through to `Other`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // No upstream prefix: the failure came from Zed's cloud, not the
            // upstream model provider.
            Self::from_http_status(
                provider::ZED_CLOUD_PROVIDER_NAME,
                status_code,
                message,
                retry_after,
            )
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from `provider` to the most specific error variant.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard "overloaded" status used by some providers.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
295
/// Why the model stopped generating.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation hit the output token limit.
    MaxTokens,
    /// The model stopped in order to call a tool.
    ToolUse,
    /// The model refused to respond.
    Refusal,
}
304
/// Token counts reported by a provider for a request.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Prompt (input) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Generated (output) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens counted toward prompt-cache creation (per provider reporting).
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from the prompt cache (per provider reporting).
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
316
317impl TokenUsage {
318 pub fn total_tokens(&self) -> u64 {
319 self.input_tokens
320 + self.output_tokens
321 + self.cache_read_input_tokens
322 + self.cache_creation_input_tokens
323 }
324}
325
326impl Add<TokenUsage> for TokenUsage {
327 type Output = Self;
328
329 fn add(self, other: Self) -> Self {
330 Self {
331 input_tokens: self.input_tokens + other.input_tokens,
332 output_tokens: self.output_tokens + other.output_tokens,
333 cache_creation_input_tokens: self.cache_creation_input_tokens
334 + other.cache_creation_input_tokens,
335 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
336 }
337 }
338}
339
340impl Sub<TokenUsage> for TokenUsage {
341 type Output = Self;
342
343 fn sub(self, other: Self) -> Self {
344 Self {
345 input_tokens: self.input_tokens - other.input_tokens,
346 output_tokens: self.output_tokens - other.output_tokens,
347 cache_creation_input_tokens: self.cache_creation_input_tokens
348 - other.cache_creation_input_tokens,
349 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
350 }
351 }
352}
353
/// Identifier of a single tool use, assigned by the model/provider.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
356
357impl fmt::Display for LanguageModelToolUseId {
358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359 write!(f, "{}", self.0)
360 }
361}
362
// Blanket conversion: anything convertible to `Arc<str>` (e.g. `&str`,
// `String`) can be used as a tool use ID.
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
371
/// A tool call requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    // Name of the tool being invoked.
    pub name: Arc<str>,
    // The tool input as a raw JSON string, as received from the model.
    pub raw_input: String,
    // The tool input parsed as JSON.
    pub input: serde_json::Value,
    // True once `input` contains the complete arguments (streamed tool calls
    // may arrive in partial form before this is set).
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
383
/// A text-only view of a completion stream (non-text events filtered out).
pub struct LanguageModelTextStream {
    // Provider-assigned message ID, if one was reported at stream start.
    pub message_id: Option<String>,
    // Stream of text chunks; errors are passed through.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
390
impl Default for LanguageModelTextStream {
    /// An empty stream with no message ID and zeroed token usage.
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
400
/// A selectable effort level for models that support thinking.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    // Human-readable name (shown in the UI).
    pub name: SharedString,
    // The effort level's underlying value — presumably what is sent to the
    // provider; confirm against callers.
    pub value: SharedString,
    // Whether this level is the default (see `default_effort_level`).
    pub is_default: bool,
}
407
408pub trait LanguageModel: Send + Sync {
409 fn id(&self) -> LanguageModelId;
410 fn name(&self) -> LanguageModelName;
411 fn provider_id(&self) -> LanguageModelProviderId;
412 fn provider_name(&self) -> LanguageModelProviderName;
413 fn upstream_provider_id(&self) -> LanguageModelProviderId {
414 self.provider_id()
415 }
416 fn upstream_provider_name(&self) -> LanguageModelProviderName {
417 self.provider_name()
418 }
419
420 /// Returns whether this model is the "latest", so we can highlight it in the UI.
421 fn is_latest(&self) -> bool {
422 false
423 }
424
425 fn telemetry_id(&self) -> String;
426
427 fn api_key(&self, _cx: &App) -> Option<String> {
428 None
429 }
430
431 /// Information about the cost of using this model, if available.
432 fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
433 None
434 }
435
436 /// Whether this model supports thinking.
437 fn supports_thinking(&self) -> bool {
438 false
439 }
440
441 fn supports_fast_mode(&self) -> bool {
442 false
443 }
444
445 /// Returns the list of supported effort levels that can be used when thinking.
446 fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
447 Vec::new()
448 }
449
450 /// Returns the default effort level to use when thinking.
451 fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
452 self.supported_effort_levels()
453 .into_iter()
454 .find(|effort_level| effort_level.is_default)
455 }
456
457 /// Whether this model supports images
458 fn supports_images(&self) -> bool;
459
460 /// Whether this model supports tools.
461 fn supports_tools(&self) -> bool;
462
463 /// Whether this model supports choosing which tool to use.
464 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
465
466 /// Returns whether this model or provider supports streaming tool calls;
467 fn supports_streaming_tools(&self) -> bool {
468 false
469 }
470
471 /// Returns whether this model/provider reports accurate split input/output token counts.
472 /// When true, the UI may show separate input/output token indicators.
473 fn supports_split_token_display(&self) -> bool {
474 false
475 }
476
477 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
478 LanguageModelToolSchemaFormat::JsonSchema
479 }
480
481 fn max_token_count(&self) -> u64;
482 fn max_output_tokens(&self) -> Option<u64> {
483 None
484 }
485
486 fn count_tokens(
487 &self,
488 request: LanguageModelRequest,
489 cx: &App,
490 ) -> BoxFuture<'static, Result<u64>>;
491
492 fn stream_completion(
493 &self,
494 request: LanguageModelRequest,
495 cx: &AsyncApp,
496 ) -> BoxFuture<
497 'static,
498 Result<
499 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
500 LanguageModelCompletionError,
501 >,
502 >;
503
504 fn stream_completion_text(
505 &self,
506 request: LanguageModelRequest,
507 cx: &AsyncApp,
508 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
509 let future = self.stream_completion(request, cx);
510
511 async move {
512 let events = future.await?;
513 let mut events = events.fuse();
514 let mut message_id = None;
515 let mut first_item_text = None;
516 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
517
518 if let Some(first_event) = events.next().await {
519 match first_event {
520 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
521 message_id = Some(id);
522 }
523 Ok(LanguageModelCompletionEvent::Text(text)) => {
524 first_item_text = Some(text);
525 }
526 _ => (),
527 }
528 }
529
530 let stream = futures::stream::iter(first_item_text.map(Ok))
531 .chain(events.filter_map({
532 let last_token_usage = last_token_usage.clone();
533 move |result| {
534 let last_token_usage = last_token_usage.clone();
535 async move {
536 match result {
537 Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
538 Ok(LanguageModelCompletionEvent::Started) => None,
539 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
540 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
541 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
542 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
543 Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
544 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
545 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
546 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
547 ..
548 }) => None,
549 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
550 *last_token_usage.lock() = token_usage;
551 None
552 }
553 Err(err) => Some(Err(err)),
554 }
555 }
556 }
557 }))
558 .boxed();
559
560 Ok(LanguageModelTextStream {
561 message_id,
562 stream,
563 last_token_usage,
564 })
565 }
566 .boxed()
567 }
568
569 fn stream_completion_tool(
570 &self,
571 request: LanguageModelRequest,
572 cx: &AsyncApp,
573 ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
574 let future = self.stream_completion(request, cx);
575
576 async move {
577 let events = future.await?;
578 let mut events = events.fuse();
579
580 // Iterate through events until we find a complete ToolUse
581 while let Some(event) = events.next().await {
582 match event {
583 Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
584 if tool_use.is_input_complete =>
585 {
586 return Ok(tool_use);
587 }
588 Err(err) => {
589 return Err(err);
590 }
591 _ => {}
592 }
593 }
594
595 // Stream ended without a complete tool use
596 Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
597 "Stream ended without receiving a complete tool use"
598 )))
599 }
600 .boxed()
601 }
602
603 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
604 None
605 }
606
607 #[cfg(any(test, feature = "test-support"))]
608 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
609 unimplemented!()
610 }
611}
612
613impl std::fmt::Debug for dyn LanguageModel {
614 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
615 f.debug_struct("<dyn LanguageModel>")
616 .field("id", &self.id())
617 .field("name", &self.name())
618 .field("provider_id", &self.provider_id())
619 .field("provider_name", &self.provider_name())
620 .field("upstream_provider_name", &self.upstream_provider_name())
621 .field("upstream_provider_id", &self.upstream_provider_id())
622 .field("upstream_provider_id", &self.upstream_provider_id())
623 .field("supports_streaming_tools", &self.supports_streaming_tools())
624 .finish()
625 }
626}
627
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
638
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}
647
impl Default for IconOrSvg {
    /// Falls back to the built-in Zed assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
653
/// A provider of language models that can enumerate its models and manage
/// its own authentication and configuration UI.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI (defaults to Zed's assistant icon).
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if one can be determined.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently exposes.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models to highlight as recommended (empty by default).
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration view for this provider, targeted at `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
676
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// Another agent, identified by name.
    Other(SharedString),
}
683
/// Exposes an observable entity so callers can react to provider state changes.
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    /// The entity to observe, if the provider has one.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Observes the provider's entity, invoking `callback` whenever it
    /// changes. Returns `None` when there is no observable entity.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
700
/// Identifier of a language model (newtype over `SharedString`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
703
/// Human-readable name of a language model (newtype over `SharedString`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
706
/// Identifier of a language model provider (newtype over `SharedString`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
709
/// Human-readable name of a language model provider (newtype over `SharedString`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
712
/// Information about the cost of using a model.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (the field names'
    /// `per_1m` suffix; the previous comment incorrectly said "1,000").
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
723
724impl LanguageModelCostInfo {
725 pub fn to_shared_string(&self) -> SharedString {
726 match self {
727 LanguageModelCostInfo::RequestCost { cost_per_request } => {
728 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
729 SharedString::from(cost_str)
730 }
731 LanguageModelCostInfo::TokenCost {
732 input_token_cost_per_1m,
733 output_token_cost_per_1m,
734 } => {
735 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
736 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
737 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
738 }
739 }
740 }
741
742 fn cost_value_to_string(cost: &f64) -> SharedString {
743 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
744 SharedString::from(format!("{:.0}", cost))
745 } else {
746 SharedString::from(format!("{:.2}", cost))
747 }
748 }
749}
750
impl LanguageModelProviderId {
    /// Creates a provider ID from a static string (usable in `const` contexts).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
756
impl LanguageModelProviderName {
    /// Creates a provider name from a static string (usable in `const` contexts).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
762
763impl fmt::Display for LanguageModelProviderId {
764 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
765 write!(f, "{}", self.0)
766 }
767}
768
769impl fmt::Display for LanguageModelProviderName {
770 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
771 write!(f, "{}", self.0)
772 }
773}
774
// Wraps an owned string in the `LanguageModelId` newtype.
impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
780
// Wraps an owned string in the `LanguageModelName` newtype.
impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
786
// Wraps an owned string in the `LanguageModelProviderId` newtype.
impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
792
// Wraps an owned string in the `LanguageModelProviderName` newtype.
impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
798
// Wraps a shared string slice in the `LanguageModelProviderId` newtype.
impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
804
// Wraps a shared string slice in the `LanguageModelProviderName` newtype.
impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
810
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies that `upstream_http_error` codes with a JSON payload are mapped
    // via the embedded `upstream_status` (503 -> ServerOverloaded,
    // 500 -> ApiInternalServerError).
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Verifies the `upstream_http_<status>` code format (status in the code
    // itself rather than a JSON payload).
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // NOTE(review): the first case below duplicates the first case of
    // `test_from_cloud_failure_with_upstream_http_error`; consider
    // consolidating the two tests.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // `thought_signature` must round-trip through serde when present.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Older payloads without `thought_signature` must still deserialize
    // (the field defaults to `None`).
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Full serialize -> deserialize round trip with a signature present.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Full serialize -> deserialize round trip without a signature.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}