1mod api_key;
2mod model;
3mod provider;
4mod rate_limiter;
5mod registry;
6mod request;
7mod role;
8pub mod tool_schema;
9
10#[cfg(any(test, feature = "test-support"))]
11pub mod fake_provider;
12
13use anyhow::{Result, anyhow};
14use client::Client;
15use client::UserStore;
16use cloud_llm_client::CompletionRequestStatus;
17use futures::FutureExt;
18use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
19use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
20use http_client::{StatusCode, http};
21use icons::IconName;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24use std::ops::{Add, Sub};
25use std::str::FromStr;
26use std::sync::Arc;
27use std::time::Duration;
28use std::{fmt, io};
29use thiserror::Error;
30use util::serde::is_default;
31
32pub use crate::api_key::{ApiKey, ApiKeyState};
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::tool_schema::LanguageModelToolSchemaFormat;
39pub use provider::*;
40pub use zed_env_vars::{EnvVar, env_var};
41
/// Initializes this crate: registers settings and wires up the LLM token
/// refresh listener for the given client/user store.
pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, user_store, cx);
}
46
/// Registers this crate's settings (the model registry) without starting any
/// client-dependent listeners; used on its own in contexts without a client.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
50
/// Configuration for a model's prompt-caching behavior, when supported.
#[derive(Clone, Debug)]
pub struct LanguageModelCacheConfiguration {
    // Maximum number of cache anchors the model allows in a single request.
    pub max_cache_anchors: usize,
    // Whether the cache should be populated speculatively.
    pub should_speculate: bool,
    // Minimum total token count before caching is applied — presumably a
    // cost/benefit threshold; confirm against providers that set this.
    pub min_total_token: u64,
}
57
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is queued upstream at the given position.
    Queued {
        position: usize,
    },
    /// The request has started streaming.
    Started,
    /// The model stopped generating, with the given reason.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking, with an optional signature that some
    /// models require to be preserved in conversation history.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-redacted thinking data.
    RedactedThinking {
        data: String,
    },
    /// A (possibly still-streaming) tool use emitted by the model.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message began, carrying the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// An updated token-usage snapshot for the request.
    UsageUpdate(TokenUsage),
}
87
88impl LanguageModelCompletionEvent {
89 pub fn from_completion_request_status(
90 status: CompletionRequestStatus,
91 upstream_provider: LanguageModelProviderName,
92 ) -> Result<Option<Self>, LanguageModelCompletionError> {
93 match status {
94 CompletionRequestStatus::Queued { position } => {
95 Ok(Some(LanguageModelCompletionEvent::Queued { position }))
96 }
97 CompletionRequestStatus::Started => Ok(Some(LanguageModelCompletionEvent::Started)),
98 CompletionRequestStatus::Unknown | CompletionRequestStatus::StreamEnded => Ok(None),
99 CompletionRequestStatus::Failed {
100 code,
101 message,
102 request_id: _,
103 retry_after,
104 } => Err(LanguageModelCompletionError::from_cloud_failure(
105 upstream_provider,
106 code,
107 message,
108 retry_after.map(Duration::from_secs_f64),
109 )),
110 }
111 }
112}
113
/// Errors produced while requesting or streaming a language model completion.
///
/// Server-side failures come first, client-side request errors follow, and
/// `Other` is a catch-all for not-yet-classified failures.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` carries the
    /// parsed token count when the provider reported one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is configured for this provider.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// The provider's rate limit was hit (mapped from HTTP 429).
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported it is overloaded (mapped from HTTP 503/529).
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported an internal error (mapped from HTTP 500).
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded from an upstream provider, message passed through verbatim.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Any other non-success HTTP response not covered by a specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// The request was malformed (mapped from HTTP 400).
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// Authentication failed (mapped from HTTP 401).
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The credentials lack permission for this request (mapped from HTTP 403).
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The API endpoint was not found (mapped from HTTP 404).
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// Reading the response body failed at the I/O level.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// Serializing the outgoing request to JSON failed.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// Building the HTTP request body failed.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed (e.g. a connection-level error).
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// Deserializing the provider's response failed.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    /// The event stream ended before a terminal event arrived.
    #[error("stream from {provider} ended unexpectedly")]
    StreamEndedUnexpectedly { provider: LanguageModelProviderName },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
204
impl LanguageModelCompletionError {
    /// Parses the JSON payload of an `upstream_http_error` cloud failure.
    ///
    /// Returns the upstream HTTP status and the inner `message` field
    /// (falling back to the raw payload when `message` is absent), or `None`
    /// when the payload is not JSON or `upstream_status` is missing/invalid.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Converts a cloud completion failure (error `code` plus `message`) into
    /// a typed error.
    ///
    /// The checks are order-sensitive: the prompt-too-long probe runs first
    /// (see the workaround note below), and the literal `upstream_http_error`
    /// code must be handled before the generic `upstream_http_<status>` and
    /// `http_<status>` prefix matches.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Unparseable upstream payload: fall back to an opaque error.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // A plain `http_<status>` code comes from Zed's cloud itself, not
            // from the upstream model provider.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status code returned by `provider` to a typed error.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard "overloaded" status used by some
            // providers (notably Anthropic) — treated like 503.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
296
/// Why a model stopped generating.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation hit the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
305
/// Token counts reported by a provider for a request. All fields default to
/// zero and are omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Output (response) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Tokens written to the prompt cache — assumed provider semantics; see
    /// each provider's usage reporting.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Tokens read from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
317
318impl TokenUsage {
319 pub fn total_tokens(&self) -> u64 {
320 self.input_tokens
321 + self.output_tokens
322 + self.cache_read_input_tokens
323 + self.cache_creation_input_tokens
324 }
325}
326
327impl Add<TokenUsage> for TokenUsage {
328 type Output = Self;
329
330 fn add(self, other: Self) -> Self {
331 Self {
332 input_tokens: self.input_tokens + other.input_tokens,
333 output_tokens: self.output_tokens + other.output_tokens,
334 cache_creation_input_tokens: self.cache_creation_input_tokens
335 + other.cache_creation_input_tokens,
336 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
337 }
338 }
339}
340
341impl Sub<TokenUsage> for TokenUsage {
342 type Output = Self;
343
344 fn sub(self, other: Self) -> Self {
345 Self {
346 input_tokens: self.input_tokens - other.input_tokens,
347 output_tokens: self.output_tokens - other.output_tokens,
348 cache_creation_input_tokens: self.cache_creation_input_tokens
349 - other.cache_creation_input_tokens,
350 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
351 }
352 }
353}
354
/// Unique identifier of a tool use within a conversation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
357
358impl fmt::Display for LanguageModelToolUseId {
359 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360 write!(f, "{}", self.0)
361 }
362}
363
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    /// Builds an id from anything convertible to `Arc<str>` (e.g. `&str`, `String`).
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
372
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The raw (possibly partial) string the model produced as input.
    pub raw_input: String,
    /// The parsed tool input.
    pub input: serde_json::Value,
    /// Whether `input` represents the complete tool input; streaming may
    /// surface partial inputs first.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
384
/// A text-only view over a completion stream.
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when one was reported.
    pub message_id: Option<String>,
    /// Stream of text chunks; non-text events are filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
391
392impl Default for LanguageModelTextStream {
393 fn default() -> Self {
394 Self {
395 message_id: None,
396 stream: Box::pin(futures::stream::empty()),
397 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
398 }
399 }
400}
401
/// A selectable "effort" level for models that support thinking.
#[derive(Debug, Clone)]
pub struct LanguageModelEffortLevel {
    /// Display name of the effort level.
    pub name: SharedString,
    /// Underlying value for this effort level.
    pub value: SharedString,
    /// Whether this is the model's default effort level.
    pub is_default: bool,
}
408
409pub trait LanguageModel: Send + Sync {
410 fn id(&self) -> LanguageModelId;
411 fn name(&self) -> LanguageModelName;
412 fn provider_id(&self) -> LanguageModelProviderId;
413 fn provider_name(&self) -> LanguageModelProviderName;
414 fn upstream_provider_id(&self) -> LanguageModelProviderId {
415 self.provider_id()
416 }
417 fn upstream_provider_name(&self) -> LanguageModelProviderName {
418 self.provider_name()
419 }
420
421 /// Returns whether this model is the "latest", so we can highlight it in the UI.
422 fn is_latest(&self) -> bool {
423 false
424 }
425
426 fn telemetry_id(&self) -> String;
427
428 fn api_key(&self, _cx: &App) -> Option<String> {
429 None
430 }
431
432 /// Information about the cost of using this model, if available.
433 fn model_cost_info(&self) -> Option<LanguageModelCostInfo> {
434 None
435 }
436
437 /// Whether this model supports thinking.
438 fn supports_thinking(&self) -> bool {
439 false
440 }
441
442 fn supports_fast_mode(&self) -> bool {
443 false
444 }
445
446 /// Returns the list of supported effort levels that can be used when thinking.
447 fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
448 Vec::new()
449 }
450
451 /// Returns the default effort level to use when thinking.
452 fn default_effort_level(&self) -> Option<LanguageModelEffortLevel> {
453 self.supported_effort_levels()
454 .into_iter()
455 .find(|effort_level| effort_level.is_default)
456 }
457
458 /// Whether this model supports images
459 fn supports_images(&self) -> bool;
460
461 /// Whether this model supports tools.
462 fn supports_tools(&self) -> bool;
463
464 /// Whether this model supports choosing which tool to use.
465 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
466
467 /// Returns whether this model or provider supports streaming tool calls;
468 fn supports_streaming_tools(&self) -> bool {
469 false
470 }
471
472 /// Returns whether this model/provider reports accurate split input/output token counts.
473 /// When true, the UI may show separate input/output token indicators.
474 fn supports_split_token_display(&self) -> bool {
475 false
476 }
477
478 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
479 LanguageModelToolSchemaFormat::JsonSchema
480 }
481
482 fn max_token_count(&self) -> u64;
483 fn max_output_tokens(&self) -> Option<u64> {
484 None
485 }
486
487 fn count_tokens(
488 &self,
489 request: LanguageModelRequest,
490 cx: &App,
491 ) -> BoxFuture<'static, Result<u64>>;
492
493 fn stream_completion(
494 &self,
495 request: LanguageModelRequest,
496 cx: &AsyncApp,
497 ) -> BoxFuture<
498 'static,
499 Result<
500 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
501 LanguageModelCompletionError,
502 >,
503 >;
504
505 fn stream_completion_text(
506 &self,
507 request: LanguageModelRequest,
508 cx: &AsyncApp,
509 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
510 let future = self.stream_completion(request, cx);
511
512 async move {
513 let events = future.await?;
514 let mut events = events.fuse();
515 let mut message_id = None;
516 let mut first_item_text = None;
517 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
518
519 if let Some(first_event) = events.next().await {
520 match first_event {
521 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
522 message_id = Some(id);
523 }
524 Ok(LanguageModelCompletionEvent::Text(text)) => {
525 first_item_text = Some(text);
526 }
527 _ => (),
528 }
529 }
530
531 let stream = futures::stream::iter(first_item_text.map(Ok))
532 .chain(events.filter_map({
533 let last_token_usage = last_token_usage.clone();
534 move |result| {
535 let last_token_usage = last_token_usage.clone();
536 async move {
537 match result {
538 Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
539 Ok(LanguageModelCompletionEvent::Started) => None,
540 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
541 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
542 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
543 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
544 Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
545 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
546 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
547 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
548 ..
549 }) => None,
550 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
551 *last_token_usage.lock() = token_usage;
552 None
553 }
554 Err(err) => Some(Err(err)),
555 }
556 }
557 }
558 }))
559 .boxed();
560
561 Ok(LanguageModelTextStream {
562 message_id,
563 stream,
564 last_token_usage,
565 })
566 }
567 .boxed()
568 }
569
570 fn stream_completion_tool(
571 &self,
572 request: LanguageModelRequest,
573 cx: &AsyncApp,
574 ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
575 let future = self.stream_completion(request, cx);
576
577 async move {
578 let events = future.await?;
579 let mut events = events.fuse();
580
581 // Iterate through events until we find a complete ToolUse
582 while let Some(event) = events.next().await {
583 match event {
584 Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
585 if tool_use.is_input_complete =>
586 {
587 return Ok(tool_use);
588 }
589 Err(err) => {
590 return Err(err);
591 }
592 _ => {}
593 }
594 }
595
596 // Stream ended without a complete tool use
597 Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
598 "Stream ended without receiving a complete tool use"
599 )))
600 }
601 .boxed()
602 }
603
604 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
605 None
606 }
607
608 #[cfg(any(test, feature = "test-support"))]
609 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
610 unimplemented!()
611 }
612}
613
614impl std::fmt::Debug for dyn LanguageModel {
615 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
616 f.debug_struct("<dyn LanguageModel>")
617 .field("id", &self.id())
618 .field("name", &self.name())
619 .field("provider_id", &self.provider_id())
620 .field("provider_name", &self.provider_name())
621 .field("upstream_provider_name", &self.upstream_provider_name())
622 .field("upstream_provider_id", &self.upstream_provider_id())
623 .field("upstream_provider_id", &self.upstream_provider_id())
624 .field("supports_streaming_tools", &self.supports_streaming_tools())
625 .finish()
626 }
627}
628
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The connection to the provider was refused.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
639
/// Either a built-in icon name or a path to an external SVG.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum IconOrSvg {
    /// A built-in icon from Zed's icon set.
    Icon(IconName),
    /// Path to a custom SVG icon file.
    Svg(SharedString),
}

impl Default for IconOrSvg {
    /// Defaults to Zed's built-in assistant icon.
    fn default() -> Self {
        Self::Icon(IconName::ZedAssistant)
    }
}
654
/// A source of language models (e.g. one API vendor) that can be registered
/// with the model registry.
pub trait LanguageModelProvider: 'static {
    /// Stable identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to Zed's assistant icon.
    fn icon(&self) -> IconOrSvg {
        IconOrSvg::default()
    }
    /// The provider's default model, if one can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if one can be determined.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider currently has usable credentials.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authentication (e.g. loading or validating credentials).
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
677
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone, PartialEq, Eq)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// Another agent, identified by name.
    Other(SharedString),
}
684
685pub trait LanguageModelProviderState: 'static {
686 type ObservableEntity;
687
688 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
689
690 fn subscribe<T: 'static>(
691 &self,
692 cx: &mut gpui::Context<T>,
693 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
694 ) -> Option<gpui::Subscription> {
695 let entity = self.observable_entity()?;
696 Some(cx.observe(&entity, move |this, _, cx| {
697 callback(this, cx);
698 }))
699 }
700}
701
/// Identifier of a language model, unique within its provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
713
/// Pricing information for a model, for display in the UI.
#[derive(Clone, Debug, PartialEq)]
pub enum LanguageModelCostInfo {
    /// Cost per 1 million input and output tokens (matching the `_per_1m`
    /// field suffix; an earlier comment incorrectly said "per 1,000").
    TokenCost {
        input_token_cost_per_1m: f64,
        output_token_cost_per_1m: f64,
    },
    /// Cost per request
    RequestCost { cost_per_request: f64 },
}
724
725impl LanguageModelCostInfo {
726 pub fn to_shared_string(&self) -> SharedString {
727 match self {
728 LanguageModelCostInfo::RequestCost { cost_per_request } => {
729 let cost_str = format!("{}×", Self::cost_value_to_string(cost_per_request));
730 SharedString::from(cost_str)
731 }
732 LanguageModelCostInfo::TokenCost {
733 input_token_cost_per_1m,
734 output_token_cost_per_1m,
735 } => {
736 let input_cost = Self::cost_value_to_string(input_token_cost_per_1m);
737 let output_cost = Self::cost_value_to_string(output_token_cost_per_1m);
738 SharedString::from(format!("{}$/{}$", input_cost, output_cost))
739 }
740 }
741 }
742
743 fn cost_value_to_string(cost: &f64) -> SharedString {
744 if (cost.fract() - 0.0).abs() < std::f64::EPSILON {
745 SharedString::from(format!("{:.0}", cost))
746 } else {
747 SharedString::from(format!("{:.2}", cost))
748 }
749 }
750}
751
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Creates a provider name from a static string; usable in `const` contexts.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
763
764impl fmt::Display for LanguageModelProviderId {
765 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766 write!(f, "{}", self.0)
767 }
768}
769
770impl fmt::Display for LanguageModelProviderName {
771 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
772 write!(f, "{}", self.0)
773 }
774}
775
// Infallible conversions from owned strings and shared string slices into
// the id/name newtypes; each simply wraps the value in a `SharedString`.

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
811
#[cfg(test)]
mod tests {
    use super::*;

    /// `upstream_http_error` payloads carry the real status in JSON:
    /// 503 must map to `ServerOverloaded`, 500 to `ApiInternalServerError`.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    /// The `upstream_http_<status>` code format carries the status inline.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    /// NOTE(review): the first half of this test duplicates
    /// `test_from_cloud_failure_with_upstream_http_error` above — consider
    /// consolidating; the 500-status half also asserts the inner message is
    /// extracted from the JSON payload.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    /// `thought_signature` must round-trip through serde when present.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    /// Older payloads without `thought_signature` must still deserialize.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}