mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::any::Any;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
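///
/// Providers with prompt caching use these values to decide how many cache
/// anchors to place and when caching is worth attempting. A minimal sketch of
/// constructing one; the specific values below are illustrative, not taken
/// from any particular provider:
///
/// ```ignore
/// let cache_config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,
///     should_speculate: true,
///     min_total_token: 10_000,
/// };
/// ```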
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
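///
/// Streaming completions yield a sequence of these events. A minimal sketch of
/// how a consumer might handle the common cases (the surrounding variable
/// names are hypothetical):
///
/// ```ignore
/// match event {
///     LanguageModelCompletionEvent::Text(chunk) => buffer.push_str(&chunk),
///     LanguageModelCompletionEvent::UsageUpdate(usage) => latest_usage = usage,
///     LanguageModelCompletionEvent::Stop(reason) => log::debug!("stopped: {reason:?}"),
///     _ => {}
/// }
/// ```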
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
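    /// Extracts an upstream `(status, message)` pair from a JSON error payload
    /// of the form `{"message":"...","upstream_status":503}`. Returns `None`
    /// when the message is not JSON or when `upstream_status` is missing or
    /// not a valid HTTP status code.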
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

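    /// Converts a cloud completion failure (an error `code` string plus a
    /// `message`) into a structured error. Recognized codes are
    /// `upstream_http_error` (whose message carries a JSON body with an
    /// `upstream_status` field), `upstream_http_<status>`, and `http_<status>`;
    /// anything else falls back to [`Self::Other`].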
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

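    /// Maps an HTTP status code from a provider's API to a structured error:
    /// 400/401/403/404 become the corresponding client errors, 413 becomes
    /// [`Self::PromptTooLarge`], 429 becomes [`Self::RateLimitExceeded`],
    /// 500 becomes [`Self::ApiInternalServerError`], and 503 or the
    /// non-standard 529 become [`Self::ServerOverloaded`]. Everything else is
    /// reported as a generic [`Self::HttpResponseError`].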
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

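/// Token counts reported by a provider for a request, including prompt-cache
/// creation and cache-read tokens where the provider supports caching.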
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

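/// Field-wise subtraction, useful for computing the usage delta between two
/// cumulative [`TokenUsage`] snapshots. Note that each field underflows
/// (panicking in debug builds) if `other` exceeds `self`.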
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Contains the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

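    /// Adapter over [`Self::stream_completion`] that yields only the text
    /// chunks of the response. The message ID is captured from a leading
    /// `StartMessage` event when present, and the final token usage is
    /// recorded into [`LanguageModelTextStream::last_token_usage`] as
    /// `UsageUpdate` events arrive; all other event kinds are dropped.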
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
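    /// Resolves the effective context window for `mode`: `CompletionMode::Max`
    /// ("burn mode") uses `max_token_count_in_burn_mode` when the model
    /// provides one, falling back to the normal `max_token_count` otherwise.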
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}

impl LanguageModelExt for dyn LanguageModel {}

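/// A tool that can be exposed to a language model, pairing a
/// `JsonSchema`-derived input type with a name and description. A minimal
/// sketch of an implementation (the tool name and fields are hypothetical):
///
/// ```ignore
/// #[derive(serde::Deserialize, schemars::JsonSchema)]
/// struct OpenFileTool {
///     /// The path of the file to open.
///     path: String,
/// }
///
/// impl LanguageModelTool for OpenFileTool {
///     fn name() -> String {
///         "open_file".into()
///     }
///
///     fn description() -> String {
///         "Opens the file at the given path".into()
///     }
/// }
/// ```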
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: Any + Send + Sync {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}