language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::any::Any;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

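    /// Converts a failure reported by the Zed cloud proxy into a typed
    /// completion error, attributing `upstream_http_*` errors to the upstream
    /// provider and plain `http_*` errors to Zed itself.
    ///
    /// A minimal sketch of the mapping, mirroring the tests at the bottom of
    /// this file (marked `ignore` so it is not run as a doctest):
    ///
    /// ```ignore
    /// let err = LanguageModelCompletionError::from_cloud_failure(
    ///     ANTHROPIC_PROVIDER_NAME,
    ///     "upstream_http_503".to_string(),
    ///     "Service unavailable".to_string(),
    ///     None,
    /// );
    /// assert!(matches!(
    ///     err,
    ///     LanguageModelCompletionError::ServerOverloaded { .. }
    /// ));
    /// ```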
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: Currently, an Anthropic PAYLOAD_TOO_LARGE response may be reported
            // as INTERNAL_SERVER_ERROR. This is a temporary workaround to handle the
            // case where the token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

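    /// Maps an HTTP status code from a provider's API to a typed error.
    /// 529 is handled specially: it is a non-standard status some providers
    /// (notably Anthropic) use to signal an overloaded server.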
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

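/// The reason a completion stream stopped, mirroring the stop reasons
/// reported by provider APIs.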
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
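    /// Sums all four token categories. A minimal arithmetic sketch (marked
    /// `ignore` so it is not run as a doctest):
    ///
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 50,
    ///     ..Default::default()
    /// };
    /// assert_eq!(usage.total_tokens(), 150);
    /// ```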
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

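/// A tool invocation requested by the model, carrying both the raw input as
/// streamed from the provider and its parsed JSON form.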
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

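/// A text-only view of a completion stream: non-text events are filtered out
/// and token usage is accumulated as updates arrive.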
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

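            // Peek at the first event: a `StartMessage` yields the message id,
            // while a leading `Text` chunk is stashed so it can be re-emitted
            // at the head of the filtered stream below.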
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

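            // Re-emit the stashed first text chunk (if any), then keep only
            // `Text` events from the remainder, recording each `UsageUpdate`
            // into the shared `last_token_usage` as a side effect.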
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

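/// Convenience methods over `dyn LanguageModel`. A minimal usage sketch,
/// assuming a `model: Arc<dyn LanguageModel>` in scope (marked `ignore` so it
/// is not run as a doctest):
///
/// ```ignore
/// let budget = model.max_token_count_for_mode(CompletionMode::Max);
/// ```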
pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: Any + Send + Sync {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

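    /// Re-runs `callback` whenever the provider's observable entity changes.
    /// Returns `None` if the provider exposes no observable entity.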
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(name: &'static str) -> Self {
        Self(SharedString::new_static(name))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}