language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
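///
/// A sketch of constructing a configuration (the values here are
/// illustrative, not recommended defaults):
///
/// ```ignore
/// let config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,
///     should_speculate: true,
///     min_total_token: 10_000,
/// };
/// ```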
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
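///
/// A sketch of consuming events (illustrative only; assumes this crate is
/// available as `language_model`):
///
/// ```ignore
/// use language_model::LanguageModelCompletionEvent;
///
/// fn handle(event: LanguageModelCompletionEvent, output: &mut String) {
///     match event {
///         LanguageModelCompletionEvent::Text(text) => output.push_str(&text),
///         LanguageModelCompletionEvent::Stop(reason) => println!("stopped: {reason:?}"),
///         _ => {}
///     }
/// }
/// ```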
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
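    // Parses a cloud error payload of the shape exercised in the tests below,
    // e.g. `{"code":"upstream_http_error","message":"...","upstream_status":503}`,
    // returning the upstream status code and the inner message when present.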
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

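    /// Converts a cloud failure code/message pair into a typed error.
    ///
    /// Recognized `code` formats, as handled below: `"upstream_http_error"`
    /// (with a JSON payload carrying `upstream_status`), `"upstream_http_<status>"`
    /// (attributed to the upstream provider), and `"http_<status>"` (attributed
    /// to Zed's cloud provider). Anything else falls back to a generic error.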
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: an Anthropic PAYLOAD_TOO_LARGE response may currently be reported as an
            // INTERNAL_SERVER_ERROR. This is a temporary workaround for the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

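    /// Maps an HTTP status code from a provider's API to a typed error:
    /// 400 to `BadRequestFormat`, 401 to `AuthenticationError`, 403 to
    /// `PermissionError`, 404 to `ApiEndpointNotFound`, 413 to `PromptTooLarge`,
    /// 429 to `RateLimitExceeded`, 500 to `ApiInternalServerError`, and 503 or
    /// 529 to `ServerOverloaded`; everything else becomes `HttpResponseError`.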
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
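    /// Returns the sum of all input, output, and cache token counts.
    ///
    /// A sketch (illustrative only; assumes this crate is available as
    /// `language_model`):
    ///
    /// ```ignore
    /// use language_model::TokenUsage;
    ///
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 50,
    ///     ..TokenUsage::default()
    /// };
    /// assert_eq!(usage.total_tokens(), 150);
    /// ```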
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

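// Note: this subtraction is field-wise and unchecked, so it panics on
// underflow in debug builds (and wraps in release builds). Callers presumably
// subtract an earlier snapshot of the same counters.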
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode.
    /// Returns `None` if `supports_burn_mode` returns `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

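            // Peek at the first event so an optional `StartMessage` can seed the
            // message ID; if the first event is already text, stash it so it can
            // be re-prepended to the stream below.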
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

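            // Re-prepend any stashed first chunk, then reduce the event stream to
            // plain text: text events pass through, usage updates are recorded in
            // `last_token_usage`, and all other events are dropped.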
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
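    /// Returns the model's maximum token count for the given completion mode,
    /// falling back to the normal limit when burn mode is unsupported.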
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

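    /// Observes this provider's observable entity, if any, invoking `callback`
    /// whenever that entity notifies. Returns `None` when the provider exposes
    /// no observable entity.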
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
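
    // Additional illustrative tests, added as documentation examples for the
    // arithmetic and conversion helpers above.
    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 10,
            cache_read_input_tokens: 5,
        };
        let b = TokenUsage {
            input_tokens: 25,
            output_tokens: 25,
            ..TokenUsage::default()
        };

        assert_eq!(a.total_tokens(), 165);
        assert_eq!((a + b).total_tokens(), 215);
        assert_eq!((a - b).input_tokens, 75);
    }

    #[test]
    fn test_from_http_status_mapping() {
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            StatusCode::TOO_MANY_REQUESTS,
            "rate limited".to_string(),
            Some(Duration::from_secs(30)),
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::RateLimitExceeded { .. }
        ));

        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            StatusCode::UNAUTHORIZED,
            "bad key".to_string(),
            None,
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::AuthenticationError { .. }
        ));
    }

    #[test]
    fn test_tool_use_id_from_str() {
        let id = LanguageModelToolUseId::from("tool-call-1");
        assert_eq!(id.to_string(), "tool-call-1");
    }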
}