mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{CompletionMode, CompletionRequestStatus};

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}
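
// A minimal sketch of constructing a cache configuration, loosely modeled on
// Anthropic-style prompt caching. The concrete values below are illustrative
// assumptions, not defaults shipped by any provider.
#[cfg(test)]
#[allow(dead_code)]
fn example_cache_configuration() -> LanguageModelCacheConfiguration {
    LanguageModelCacheConfiguration {
        // At most this many cache breakpoints per request (hypothetical value).
        max_cache_anchors: 4,
        // Opt in to speculative cache writes.
        should_speculate: true,
        // Minimum total tokens before caching kicks in (illustrative threshold).
        min_total_token: 1024,
    }
}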

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}
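
// A sketch of the event-handling loop a consumer might write over these events;
// only the variants relevant to plain text output are handled, and the function
// name is illustrative rather than part of this crate's API.
#[cfg(test)]
#[allow(dead_code)]
fn example_handle_event(event: LanguageModelCompletionEvent, buffer: &mut String) {
    match event {
        // Streamed text chunks are appended in arrival order.
        LanguageModelCompletionEvent::Text(text) => buffer.push_str(&text),
        // The stop reason says whether the turn ended naturally, hit the token
        // limit, paused for tool use, or was refused.
        LanguageModelCompletionEvent::Stop(_reason) => {}
        // Thinking, tool-use, status, and usage events are ignored in this sketch.
        _ => {}
    }
}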

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // Anthropic uses the non-standard 529 status to signal overload.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
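
// A small sketch of how `from_http_status` buckets upstream statuses; this
// function exists only as an illustration and mirrors the match arms above.
#[cfg(test)]
#[allow(dead_code)]
fn example_http_status_mapping() {
    let error = LanguageModelCompletionError::from_http_status(
        ANTHROPIC_PROVIDER_NAME,
        StatusCode::TOO_MANY_REQUESTS,
        "rate limited".to_string(),
        Some(Duration::from_secs(30)),
    );
    // 429 maps to `RateLimitExceeded`, carrying the server-provided retry delay.
    assert!(matches!(
        error,
        LanguageModelCompletionError::RateLimitExceeded {
            retry_after: Some(_),
            ..
        }
    ));
}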

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}
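
// Sketch: field-wise arithmetic on `TokenUsage`. Note that `Sub` uses unchecked
// subtraction, so subtracting a larger usage from a smaller one panics in debug
// builds; callers are expected to subtract an earlier snapshot from a later one.
#[cfg(test)]
#[allow(dead_code)]
fn example_token_usage_arithmetic() {
    let earlier = TokenUsage {
        input_tokens: 100,
        output_tokens: 10,
        ..TokenUsage::default()
    };
    let later = TokenUsage {
        input_tokens: 150,
        output_tokens: 40,
        ..TokenUsage::default()
    };
    // The delta between the two snapshots is 50 input + 30 output tokens.
    assert_eq!((later - earlier).total_tokens(), 80);
    assert_eq!((earlier + later).total_tokens(), 300);
}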

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
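
// Anything convertible into `Arc<str>` can serve as a tool-use id. The literal
// below is a made-up example, not a real provider id format.
#[cfg(test)]
#[allow(dead_code)]
fn example_tool_use_id() {
    let id = LanguageModelToolUseId::from("toolu_example_123");
    assert_eq!(id.to_string(), "toolu_example_123");
}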

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Contains the complete token usage only after the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
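
// Sketch of draining a `LanguageModelTextStream` into a single string. The
// helper name is illustrative; only `stream` and `last_token_usage` are real
// fields of the type above.
#[cfg(test)]
#[allow(dead_code)]
async fn example_drain_text_stream(
    mut text_stream: LanguageModelTextStream,
) -> Result<String, LanguageModelCompletionError> {
    let mut text = String::new();
    while let Some(chunk) = text_stream.stream.next().await {
        text.push_str(&chunk?);
    }
    // Usage is only guaranteed to be complete once the stream has finished.
    let _usage = *text_stream.last_token_usage.lock();
    Ok(text)
}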

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    /// Returns the maximum token count for this model in burn mode (`None` if
    /// `supports_burn_mode` is `false`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper around `stream_completion` that surfaces only text
    /// chunks, recording token usage as the stream progresses.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}
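
// Sketch: resolving the effective context window for a completion mode via the
// extension trait above. `CompletionMode::Max` (burn mode) falls back to the
// normal limit when the model reports no burn-mode limit.
#[cfg(test)]
#[allow(dead_code)]
fn example_effective_token_limit(model: &dyn LanguageModel) -> u64 {
    model.max_token_count_for_mode(CompletionMode::Max)
}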

pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}