language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7
  8#[cfg(any(test, feature = "test-support"))]
  9pub mod fake_provider;
 10
 11use anthropic::{AnthropicError, parse_prompt_too_long};
 12use anyhow::{Result, anyhow};
 13use client::Client;
 14use futures::FutureExt;
 15use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 16use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
 17use http_client::{StatusCode, http};
 18use icons::IconName;
 19use parking_lot::Mutex;
 20use schemars::JsonSchema;
 21use serde::{Deserialize, Serialize, de::DeserializeOwned};
 22use std::ops::{Add, Sub};
 23use std::str::FromStr;
 24use std::sync::Arc;
 25use std::time::Duration;
 26use std::{fmt, io};
 27use thiserror::Error;
 28use util::serde::is_default;
 29use zed_llm_client::{CompletionMode, CompletionRequestStatus};
 30
 31pub use crate::model::*;
 32pub use crate::rate_limiter::*;
 33pub use crate::registry::*;
 34pub use crate::request::*;
 35pub use crate::role::*;
 36pub use crate::telemetry::*;
 37
 38pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
 39    LanguageModelProviderId::new("anthropic");
 40pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
 41    LanguageModelProviderName::new("Anthropic");
 42
 43pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
 44pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
 45    LanguageModelProviderName::new("Google AI");
 46
 47pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
 48pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
 49    LanguageModelProviderName::new("OpenAI");
 50
 51pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
 52pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
 53    LanguageModelProviderName::new("Zed");
 54
 55pub fn init(client: Arc<Client>, cx: &mut App) {
 56    init_settings(cx);
 57    RefreshLlmTokenListener::register(client.clone(), cx);
 58}
 59
 60pub fn init_settings(cx: &mut App) {
 61    registry::init(cx);
 62}
 63
 64/// Configuration for caching language model messages.
 65#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 66pub struct LanguageModelCacheConfiguration {
 67    pub max_cache_anchors: usize,
 68    pub should_speculate: bool,
 69    pub min_total_token: u64,
 70}
 71
 72/// A completion event from a language model.
 73#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 74pub enum LanguageModelCompletionEvent {
 75    StatusUpdate(CompletionRequestStatus),
 76    Stop(StopReason),
 77    Text(String),
 78    Thinking {
 79        text: String,
 80        signature: Option<String>,
 81    },
 82    RedactedThinking {
 83        data: String,
 84    },
 85    ToolUse(LanguageModelToolUse),
 86    ToolUseJsonParseError {
 87        id: LanguageModelToolUseId,
 88        tool_name: Arc<str>,
 89        raw_input: Arc<str>,
 90        json_parse_error: String,
 91    },
 92    StartMessage {
 93        message_id: String,
 94    },
 95    UsageUpdate(TokenUsage),
 96}
 97
 98#[derive(Error, Debug)]
 99pub enum LanguageModelCompletionError {
100    #[error("prompt too large for context window")]
101    PromptTooLarge { tokens: Option<u64> },
102    #[error("missing {provider} API key")]
103    NoApiKey { provider: LanguageModelProviderName },
104    #[error("{provider}'s API rate limit exceeded")]
105    RateLimitExceeded {
106        provider: LanguageModelProviderName,
107        retry_after: Option<Duration>,
108    },
109    #[error("{provider}'s API servers are overloaded right now")]
110    ServerOverloaded {
111        provider: LanguageModelProviderName,
112        retry_after: Option<Duration>,
113    },
114    #[error("{provider}'s API server reported an internal server error: {message}")]
115    ApiInternalServerError {
116        provider: LanguageModelProviderName,
117        message: String,
118    },
119    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
120    HttpResponseError {
121        provider: LanguageModelProviderName,
122        status_code: StatusCode,
123        message: String,
124    },
125
126    // Client errors
127    #[error("invalid request format to {provider}'s API: {message}")]
128    BadRequestFormat {
129        provider: LanguageModelProviderName,
130        message: String,
131    },
132    #[error("authentication error with {provider}'s API: {message}")]
133    AuthenticationError {
134        provider: LanguageModelProviderName,
135        message: String,
136    },
137    #[error("permission error with {provider}'s API: {message}")]
138    PermissionError {
139        provider: LanguageModelProviderName,
140        message: String,
141    },
142    #[error("language model provider API endpoint not found")]
143    ApiEndpointNotFound { provider: LanguageModelProviderName },
144    #[error("I/O error reading response from {provider}'s API")]
145    ApiReadResponseError {
146        provider: LanguageModelProviderName,
147        #[source]
148        error: io::Error,
149    },
150    #[error("error serializing request to {provider} API")]
151    SerializeRequest {
152        provider: LanguageModelProviderName,
153        #[source]
154        error: serde_json::Error,
155    },
156    #[error("error building request body to {provider} API")]
157    BuildRequestBody {
158        provider: LanguageModelProviderName,
159        #[source]
160        error: http::Error,
161    },
162    #[error("error sending HTTP request to {provider} API")]
163    HttpSend {
164        provider: LanguageModelProviderName,
165        #[source]
166        error: anyhow::Error,
167    },
168    #[error("error deserializing {provider} API response")]
169    DeserializeResponse {
170        provider: LanguageModelProviderName,
171        #[source]
172        error: serde_json::Error,
173    },
174
175    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
176    #[error(transparent)]
177    Other(#[from] anyhow::Error),
178}
179
180impl LanguageModelCompletionError {
181    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
182        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
183        let upstream_status = error_json
184            .get("upstream_status")
185            .and_then(|v| v.as_u64())
186            .and_then(|status| u16::try_from(status).ok())
187            .and_then(|status| StatusCode::from_u16(status).ok())?;
188        let inner_message = error_json
189            .get("message")
190            .and_then(|v| v.as_str())
191            .unwrap_or(message)
192            .to_string();
193        Some((upstream_status, inner_message))
194    }
195
196    pub fn from_cloud_failure(
197        upstream_provider: LanguageModelProviderName,
198        code: String,
199        message: String,
200        retry_after: Option<Duration>,
201    ) -> Self {
202        if let Some(tokens) = parse_prompt_too_long(&message) {
203            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
204            // to be reported. This is a temporary workaround to handle this in the case where the
205            // token limit has been exceeded.
206            Self::PromptTooLarge {
207                tokens: Some(tokens),
208            }
209        } else if code == "upstream_http_error" {
210            if let Some((upstream_status, inner_message)) =
211                Self::parse_upstream_error_json(&message)
212            {
213                return Self::from_http_status(
214                    upstream_provider,
215                    upstream_status,
216                    inner_message,
217                    retry_after,
218                );
219            }
220            anyhow!("completion request failed, code: {code}, message: {message}").into()
221        } else if let Some(status_code) = code
222            .strip_prefix("upstream_http_")
223            .and_then(|code| StatusCode::from_str(code).ok())
224        {
225            Self::from_http_status(upstream_provider, status_code, message, retry_after)
226        } else if let Some(status_code) = code
227            .strip_prefix("http_")
228            .and_then(|code| StatusCode::from_str(code).ok())
229        {
230            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
231        } else {
232            anyhow!("completion request failed, code: {code}, message: {message}").into()
233        }
234    }
235
236    pub fn from_http_status(
237        provider: LanguageModelProviderName,
238        status_code: StatusCode,
239        message: String,
240        retry_after: Option<Duration>,
241    ) -> Self {
242        match status_code {
243            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
244            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
245            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
246            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
247            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
248                tokens: parse_prompt_too_long(&message),
249            },
250            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
251                provider,
252                retry_after,
253            },
254            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
255            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
256                provider,
257                retry_after,
258            },
259            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
260                provider,
261                retry_after,
262            },
263            _ => Self::HttpResponseError {
264                provider,
265                status_code,
266                message,
267            },
268        }
269    }
270}
271
272impl From<AnthropicError> for LanguageModelCompletionError {
273    fn from(error: AnthropicError) -> Self {
274        let provider = ANTHROPIC_PROVIDER_NAME;
275        match error {
276            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
277            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
278            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
279            AnthropicError::DeserializeResponse(error) => {
280                Self::DeserializeResponse { provider, error }
281            }
282            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
283            AnthropicError::HttpResponseError {
284                status_code,
285                message,
286            } => Self::HttpResponseError {
287                provider,
288                status_code,
289                message,
290            },
291            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
292                provider,
293                retry_after: Some(retry_after),
294            },
295            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
296                provider,
297                retry_after: retry_after,
298            },
299            AnthropicError::ApiError(api_error) => api_error.into(),
300        }
301    }
302}
303
304impl From<anthropic::ApiError> for LanguageModelCompletionError {
305    fn from(error: anthropic::ApiError) -> Self {
306        use anthropic::ApiErrorCode::*;
307        let provider = ANTHROPIC_PROVIDER_NAME;
308        match error.code() {
309            Some(code) => match code {
310                InvalidRequestError => Self::BadRequestFormat {
311                    provider,
312                    message: error.message,
313                },
314                AuthenticationError => Self::AuthenticationError {
315                    provider,
316                    message: error.message,
317                },
318                PermissionError => Self::PermissionError {
319                    provider,
320                    message: error.message,
321                },
322                NotFoundError => Self::ApiEndpointNotFound { provider },
323                RequestTooLarge => Self::PromptTooLarge {
324                    tokens: parse_prompt_too_long(&error.message),
325                },
326                RateLimitError => Self::RateLimitExceeded {
327                    provider,
328                    retry_after: None,
329                },
330                ApiError => Self::ApiInternalServerError {
331                    provider,
332                    message: error.message,
333                },
334                OverloadedError => Self::ServerOverloaded {
335                    provider,
336                    retry_after: None,
337                },
338            },
339            None => Self::Other(error.into()),
340        }
341    }
342}
343
344/// Indicates the format used to define the input schema for a language model tool.
345#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
346pub enum LanguageModelToolSchemaFormat {
347    /// A JSON schema, see https://json-schema.org
348    JsonSchema,
349    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
350    JsonSchemaSubset,
351}
352
353#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
354#[serde(rename_all = "snake_case")]
355pub enum StopReason {
356    EndTurn,
357    MaxTokens,
358    ToolUse,
359    Refusal,
360}
361
362#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
363pub struct TokenUsage {
364    #[serde(default, skip_serializing_if = "is_default")]
365    pub input_tokens: u64,
366    #[serde(default, skip_serializing_if = "is_default")]
367    pub output_tokens: u64,
368    #[serde(default, skip_serializing_if = "is_default")]
369    pub cache_creation_input_tokens: u64,
370    #[serde(default, skip_serializing_if = "is_default")]
371    pub cache_read_input_tokens: u64,
372}
373
374impl TokenUsage {
375    pub fn total_tokens(&self) -> u64 {
376        self.input_tokens
377            + self.output_tokens
378            + self.cache_read_input_tokens
379            + self.cache_creation_input_tokens
380    }
381}
382
383impl Add<TokenUsage> for TokenUsage {
384    type Output = Self;
385
386    fn add(self, other: Self) -> Self {
387        Self {
388            input_tokens: self.input_tokens + other.input_tokens,
389            output_tokens: self.output_tokens + other.output_tokens,
390            cache_creation_input_tokens: self.cache_creation_input_tokens
391                + other.cache_creation_input_tokens,
392            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
393        }
394    }
395}
396
397impl Sub<TokenUsage> for TokenUsage {
398    type Output = Self;
399
400    fn sub(self, other: Self) -> Self {
401        Self {
402            input_tokens: self.input_tokens - other.input_tokens,
403            output_tokens: self.output_tokens - other.output_tokens,
404            cache_creation_input_tokens: self.cache_creation_input_tokens
405                - other.cache_creation_input_tokens,
406            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
407        }
408    }
409}
410
411#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
412pub struct LanguageModelToolUseId(Arc<str>);
413
414impl fmt::Display for LanguageModelToolUseId {
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        write!(f, "{}", self.0)
417    }
418}
419
420impl<T> From<T> for LanguageModelToolUseId
421where
422    T: Into<Arc<str>>,
423{
424    fn from(value: T) -> Self {
425        Self(value.into())
426    }
427}
428
429#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
430pub struct LanguageModelToolUse {
431    pub id: LanguageModelToolUseId,
432    pub name: Arc<str>,
433    pub raw_input: String,
434    pub input: serde_json::Value,
435    pub is_input_complete: bool,
436}
437
438pub struct LanguageModelTextStream {
439    pub message_id: Option<String>,
440    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
441    // Has complete token usage after the stream has finished
442    pub last_token_usage: Arc<Mutex<TokenUsage>>,
443}
444
445impl Default for LanguageModelTextStream {
446    fn default() -> Self {
447        Self {
448            message_id: None,
449            stream: Box::pin(futures::stream::empty()),
450            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
451        }
452    }
453}
454
455pub trait LanguageModel: Send + Sync {
456    fn id(&self) -> LanguageModelId;
457    fn name(&self) -> LanguageModelName;
458    fn provider_id(&self) -> LanguageModelProviderId;
459    fn provider_name(&self) -> LanguageModelProviderName;
460    fn upstream_provider_id(&self) -> LanguageModelProviderId {
461        self.provider_id()
462    }
463    fn upstream_provider_name(&self) -> LanguageModelProviderName {
464        self.provider_name()
465    }
466
467    fn telemetry_id(&self) -> String;
468
469    fn api_key(&self, _cx: &App) -> Option<String> {
470        None
471    }
472
473    /// Whether this model supports images
474    fn supports_images(&self) -> bool;
475
476    /// Whether this model supports tools.
477    fn supports_tools(&self) -> bool;
478
479    /// Whether this model supports choosing which tool to use.
480    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
481
482    /// Returns whether this model supports "burn mode";
483    fn supports_burn_mode(&self) -> bool {
484        false
485    }
486
487    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
488        LanguageModelToolSchemaFormat::JsonSchema
489    }
490
491    fn max_token_count(&self) -> u64;
492    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
493    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
494        None
495    }
496    fn max_output_tokens(&self) -> Option<u64> {
497        None
498    }
499
500    fn count_tokens(
501        &self,
502        request: LanguageModelRequest,
503        cx: &App,
504    ) -> BoxFuture<'static, Result<u64>>;
505
506    fn stream_completion(
507        &self,
508        request: LanguageModelRequest,
509        cx: &AsyncApp,
510    ) -> BoxFuture<
511        'static,
512        Result<
513            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
514            LanguageModelCompletionError,
515        >,
516    >;
517
518    fn stream_completion_text(
519        &self,
520        request: LanguageModelRequest,
521        cx: &AsyncApp,
522    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
523        let future = self.stream_completion(request, cx);
524
525        async move {
526            let events = future.await?;
527            let mut events = events.fuse();
528            let mut message_id = None;
529            let mut first_item_text = None;
530            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
531
532            if let Some(first_event) = events.next().await {
533                match first_event {
534                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
535                        message_id = Some(id.clone());
536                    }
537                    Ok(LanguageModelCompletionEvent::Text(text)) => {
538                        first_item_text = Some(text);
539                    }
540                    _ => (),
541                }
542            }
543
544            let stream = futures::stream::iter(first_item_text.map(Ok))
545                .chain(events.filter_map({
546                    let last_token_usage = last_token_usage.clone();
547                    move |result| {
548                        let last_token_usage = last_token_usage.clone();
549                        async move {
550                            match result {
551                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
552                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
553                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
554                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
555                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
556                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
557                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
558                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
559                                    ..
560                                }) => None,
561                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
562                                    *last_token_usage.lock() = token_usage;
563                                    None
564                                }
565                                Err(err) => Some(Err(err)),
566                            }
567                        }
568                    }
569                }))
570                .boxed();
571
572            Ok(LanguageModelTextStream {
573                message_id,
574                stream,
575                last_token_usage,
576            })
577        }
578        .boxed()
579    }
580
581    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
582        None
583    }
584
585    #[cfg(any(test, feature = "test-support"))]
586    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
587        unimplemented!()
588    }
589}
590
591pub trait LanguageModelExt: LanguageModel {
592    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
593        match mode {
594            CompletionMode::Normal => self.max_token_count(),
595            CompletionMode::Max => self
596                .max_token_count_in_burn_mode()
597                .unwrap_or_else(|| self.max_token_count()),
598        }
599    }
600}
601impl LanguageModelExt for dyn LanguageModel {}
602
603pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
604    fn name() -> String;
605    fn description() -> String;
606}
607
608/// An error that occurred when trying to authenticate the language model provider.
609#[derive(Debug, Error)]
610pub enum AuthenticateError {
611    #[error("credentials not found")]
612    CredentialsNotFound,
613    #[error(transparent)]
614    Other(#[from] anyhow::Error),
615}
616
617pub trait LanguageModelProvider: 'static {
618    fn id(&self) -> LanguageModelProviderId;
619    fn name(&self) -> LanguageModelProviderName;
620    fn icon(&self) -> IconName {
621        IconName::ZedAssistant
622    }
623    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
624    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
625    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
626    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
627        Vec::new()
628    }
629    fn is_authenticated(&self, cx: &App) -> bool;
630    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
631    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
632    fn must_accept_terms(&self, _cx: &App) -> bool {
633        false
634    }
635    fn render_accept_terms(
636        &self,
637        _view: LanguageModelProviderTosView,
638        _cx: &mut App,
639    ) -> Option<AnyElement> {
640        None
641    }
642    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
643}
644
/// Where a provider's terms-of-service prompt is being rendered.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Inside the prompt editor popup.
    PromptEditorPopup,
    /// Inside the provider configuration view.
    Configuration,
}
654
655pub trait LanguageModelProviderState: 'static {
656    type ObservableEntity;
657
658    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
659
660    fn subscribe<T: 'static>(
661        &self,
662        cx: &mut gpui::Context<T>,
663        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
664    ) -> Option<gpui::Subscription> {
665        let entity = self.observable_entity()?;
666        Some(cx.observe(&entity, move |this, _, cx| {
667            callback(this, cx);
668        }))
669    }
670}
671
672#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
673pub struct LanguageModelId(pub SharedString);
674
675#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
676pub struct LanguageModelName(pub SharedString);
677
678#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
679pub struct LanguageModelProviderId(pub SharedString);
680
681#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
682pub struct LanguageModelProviderName(pub SharedString);
683
684impl LanguageModelProviderId {
685    pub const fn new(id: &'static str) -> Self {
686        Self(SharedString::new_static(id))
687    }
688}
689
690impl LanguageModelProviderName {
691    pub const fn new(id: &'static str) -> Self {
692        Self(SharedString::new_static(id))
693    }
694}
695
696impl fmt::Display for LanguageModelProviderId {
697    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
698        write!(f, "{}", self.0)
699    }
700}
701
702impl fmt::Display for LanguageModelProviderName {
703    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
704        write!(f, "{}", self.0)
705    }
706}
707
708impl From<String> for LanguageModelId {
709    fn from(value: String) -> Self {
710        Self(SharedString::from(value))
711    }
712}
713
714impl From<String> for LanguageModelName {
715    fn from(value: String) -> Self {
716        Self(SharedString::from(value))
717    }
718}
719
720impl From<String> for LanguageModelProviderId {
721    fn from(value: String) -> Self {
722        Self(SharedString::from(value))
723    }
724}
725
726impl From<String> for LanguageModelProviderName {
727    fn from(value: String) -> Self {
728        Self(SharedString::from(value))
729    }
730}
731
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream message used by the connection-timeout fixtures.
    const CONNECTION_TIMEOUT_MESSAGE: &str = "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout";

    /// Reports a cloud failure attributed to the `anthropic` upstream provider.
    fn anthropic_cloud_failure(code: &str, message: &str) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            message.to_string(),
            None,
        )
    }

    /// JSON body of an `upstream_http_error` carrying `message` and `upstream_status`.
    fn upstream_error_json(message: &str, upstream_status: u16) -> String {
        format!(
            r#"{{"code":"upstream_http_error","message":"{message}","upstream_status":{upstream_status}}}"#
        )
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_error_json(CONNECTION_TIMEOUT_MESSAGE, 503),
        );
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            other => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                other
            ),
        }

        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_error_json("Internal server error", 500),
        );
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            other => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = anthropic_cloud_failure("upstream_http_503", "Service unavailable");
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_error_json(CONNECTION_TIMEOUT_MESSAGE, 503),
        );
        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            other => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                other
            ),
        }

        let error = anthropic_cloud_failure(
            "upstream_http_error",
            &upstream_error_json(CONNECTION_TIMEOUT_MESSAGE, 500),
        );
        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, CONNECTION_TIMEOUT_MESSAGE);
            }
            other => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                other
            ),
        }
    }
}