language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7
  8#[cfg(any(test, feature = "test-support"))]
  9pub mod fake_provider;
 10
 11use anthropic::{AnthropicError, parse_prompt_too_long};
 12use anyhow::{Result, anyhow};
 13use client::Client;
 14use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
 15use futures::FutureExt;
 16use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 17use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 18use http_client::{StatusCode, http};
 19use icons::IconName;
 20use open_router::OpenRouterError;
 21use parking_lot::Mutex;
 22use schemars::JsonSchema;
 23use serde::{Deserialize, Serialize, de::DeserializeOwned};
 24use std::ops::{Add, Sub};
 25use std::str::FromStr;
 26use std::sync::Arc;
 27use std::time::Duration;
 28use std::{fmt, io};
 29use thiserror::Error;
 30use util::serde::is_default;
 31
 32pub use crate::model::*;
 33pub use crate::rate_limiter::*;
 34pub use crate::registry::*;
 35pub use crate::request::*;
 36pub use crate::role::*;
 37pub use crate::telemetry::*;
 38
// Well-known provider identifiers and display names. The `*_ID` constants are
// stable machine identifiers; the `*_NAME` constants are user-facing labels.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

// Zed's own cloud provider; used as the fallback provider name for `http_*`
// error codes in `LanguageModelCompletionError::from_cloud_failure`.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
 55
/// Initializes the language-model crate: registers settings and hooks up the
/// LLM token refresh listener on the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
 60
/// Initializes only the model registry/settings (no client required); also
/// called by `init`.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
 64
/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// Maximum number of cache anchors a request may carry.
    pub max_cache_anchors: usize,
    /// Whether to speculatively populate the cache — NOTE(review): exact
    /// semantics are provider-defined; confirm against implementors.
    pub should_speculate: bool,
    /// Minimum total token count for caching to apply (presumably; verify
    /// against provider implementations).
    pub min_total_token: u64,
}
 72
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Status report for the underlying completion request.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of "thinking"/reasoning output, with an optional provider
    /// signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Provider-redacted (opaque) thinking data.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message started; carries the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage for the request so far.
    UsageUpdate(TokenUsage),
}
 98
/// Errors produced while requesting or streaming a completion. The display
/// strings come from the `thiserror` attributes; variants carrying a
/// `provider` are tagged with the provider that produced the failure.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window. `tokens` is the
    /// parsed token count when the provider's message included one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is configured for the provider.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// HTTP 429 from the provider; `retry_after` is the suggested backoff.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 503 (or non-standard 529) from the provider.
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 500 from the provider.
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded verbatim from an upstream provider.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Any other HTTP error status not covered by a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// HTTP 400: the request body was malformed.
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 401: the credentials were rejected.
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 403: the credentials lack permission.
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 404: the API endpoint does not exist.
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// The response body could not be read from the socket.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// The request payload could not be serialized to JSON.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// The HTTP request body could not be constructed.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// The HTTP request failed to send.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// The response body could not be deserialized from JSON.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
186
187impl LanguageModelCompletionError {
188    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
189        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
190        let upstream_status = error_json
191            .get("upstream_status")
192            .and_then(|v| v.as_u64())
193            .and_then(|status| u16::try_from(status).ok())
194            .and_then(|status| StatusCode::from_u16(status).ok())?;
195        let inner_message = error_json
196            .get("message")
197            .and_then(|v| v.as_str())
198            .unwrap_or(message)
199            .to_string();
200        Some((upstream_status, inner_message))
201    }
202
203    pub fn from_cloud_failure(
204        upstream_provider: LanguageModelProviderName,
205        code: String,
206        message: String,
207        retry_after: Option<Duration>,
208    ) -> Self {
209        if let Some(tokens) = parse_prompt_too_long(&message) {
210            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
211            // to be reported. This is a temporary workaround to handle this in the case where the
212            // token limit has been exceeded.
213            Self::PromptTooLarge {
214                tokens: Some(tokens),
215            }
216        } else if code == "upstream_http_error" {
217            if let Some((upstream_status, inner_message)) =
218                Self::parse_upstream_error_json(&message)
219            {
220                return Self::from_http_status(
221                    upstream_provider,
222                    upstream_status,
223                    inner_message,
224                    retry_after,
225                );
226            }
227            anyhow!("completion request failed, code: {code}, message: {message}").into()
228        } else if let Some(status_code) = code
229            .strip_prefix("upstream_http_")
230            .and_then(|code| StatusCode::from_str(code).ok())
231        {
232            Self::from_http_status(upstream_provider, status_code, message, retry_after)
233        } else if let Some(status_code) = code
234            .strip_prefix("http_")
235            .and_then(|code| StatusCode::from_str(code).ok())
236        {
237            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
238        } else {
239            anyhow!("completion request failed, code: {code}, message: {message}").into()
240        }
241    }
242
243    pub fn from_http_status(
244        provider: LanguageModelProviderName,
245        status_code: StatusCode,
246        message: String,
247        retry_after: Option<Duration>,
248    ) -> Self {
249        match status_code {
250            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
251            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
252            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
253            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
254            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
255                tokens: parse_prompt_too_long(&message),
256            },
257            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
258                provider,
259                retry_after,
260            },
261            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
262            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
263                provider,
264                retry_after,
265            },
266            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
267                provider,
268                retry_after,
269            },
270            _ => Self::HttpResponseError {
271                provider,
272                status_code,
273                message,
274            },
275        }
276    }
277}
278
279impl From<AnthropicError> for LanguageModelCompletionError {
280    fn from(error: AnthropicError) -> Self {
281        let provider = ANTHROPIC_PROVIDER_NAME;
282        match error {
283            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
284            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
285            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
286            AnthropicError::DeserializeResponse(error) => {
287                Self::DeserializeResponse { provider, error }
288            }
289            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
290            AnthropicError::HttpResponseError {
291                status_code,
292                message,
293            } => Self::HttpResponseError {
294                provider,
295                status_code,
296                message,
297            },
298            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
299                provider,
300                retry_after: Some(retry_after),
301            },
302            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
303                provider,
304                retry_after,
305            },
306            AnthropicError::ApiError(api_error) => api_error.into(),
307        }
308    }
309}
310
311impl From<anthropic::ApiError> for LanguageModelCompletionError {
312    fn from(error: anthropic::ApiError) -> Self {
313        use anthropic::ApiErrorCode::*;
314        let provider = ANTHROPIC_PROVIDER_NAME;
315        match error.code() {
316            Some(code) => match code {
317                InvalidRequestError => Self::BadRequestFormat {
318                    provider,
319                    message: error.message,
320                },
321                AuthenticationError => Self::AuthenticationError {
322                    provider,
323                    message: error.message,
324                },
325                PermissionError => Self::PermissionError {
326                    provider,
327                    message: error.message,
328                },
329                NotFoundError => Self::ApiEndpointNotFound { provider },
330                RequestTooLarge => Self::PromptTooLarge {
331                    tokens: parse_prompt_too_long(&error.message),
332                },
333                RateLimitError => Self::RateLimitExceeded {
334                    provider,
335                    retry_after: None,
336                },
337                ApiError => Self::ApiInternalServerError {
338                    provider,
339                    message: error.message,
340                },
341                OverloadedError => Self::ServerOverloaded {
342                    provider,
343                    retry_after: None,
344                },
345            },
346            None => Self::Other(error.into()),
347        }
348    }
349}
350
351impl From<OpenRouterError> for LanguageModelCompletionError {
352    fn from(error: OpenRouterError) -> Self {
353        let provider = LanguageModelProviderName::new("OpenRouter");
354        match error {
355            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
356            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
357            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
358            OpenRouterError::DeserializeResponse(error) => {
359                Self::DeserializeResponse { provider, error }
360            }
361            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
362            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
363                provider,
364                retry_after: Some(retry_after),
365            },
366            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
367                provider,
368                retry_after,
369            },
370            OpenRouterError::ApiError(api_error) => api_error.into(),
371        }
372    }
373}
374
375impl From<open_router::ApiError> for LanguageModelCompletionError {
376    fn from(error: open_router::ApiError) -> Self {
377        use open_router::ApiErrorCode::*;
378        let provider = LanguageModelProviderName::new("OpenRouter");
379        match error.code {
380            InvalidRequestError => Self::BadRequestFormat {
381                provider,
382                message: error.message,
383            },
384            AuthenticationError => Self::AuthenticationError {
385                provider,
386                message: error.message,
387            },
388            PaymentRequiredError => Self::AuthenticationError {
389                provider,
390                message: format!("Payment required: {}", error.message),
391            },
392            PermissionError => Self::PermissionError {
393                provider,
394                message: error.message,
395            },
396            RequestTimedOut => Self::HttpResponseError {
397                provider,
398                status_code: StatusCode::REQUEST_TIMEOUT,
399                message: error.message,
400            },
401            RateLimitError => Self::RateLimitExceeded {
402                provider,
403                retry_after: None,
404            },
405            ApiError => Self::ApiInternalServerError {
406                provider,
407                message: error.message,
408            },
409            OverloadedError => Self::ServerOverloaded {
410                provider,
411                retry_after: None,
412            },
413        }
414    }
415}
416
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}
425
/// Why the model stopped generating; serialized in snake_case.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output-token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
434
/// Token counts for a completion request; zero fields are omitted when
/// serialized.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the request input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens produced in the response.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
446
447impl TokenUsage {
448    pub fn total_tokens(&self) -> u64 {
449        self.input_tokens
450            + self.output_tokens
451            + self.cache_read_input_tokens
452            + self.cache_creation_input_tokens
453    }
454}
455
456impl Add<TokenUsage> for TokenUsage {
457    type Output = Self;
458
459    fn add(self, other: Self) -> Self {
460        Self {
461            input_tokens: self.input_tokens + other.input_tokens,
462            output_tokens: self.output_tokens + other.output_tokens,
463            cache_creation_input_tokens: self.cache_creation_input_tokens
464                + other.cache_creation_input_tokens,
465            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
466        }
467    }
468}
469
470impl Sub<TokenUsage> for TokenUsage {
471    type Output = Self;
472
473    fn sub(self, other: Self) -> Self {
474        Self {
475            input_tokens: self.input_tokens - other.input_tokens,
476            output_tokens: self.output_tokens - other.output_tokens,
477            cache_creation_input_tokens: self.cache_creation_input_tokens
478                - other.cache_creation_input_tokens,
479            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
480        }
481    }
482}
483
/// Opaque identifier for a single tool invocation within a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
486
487impl fmt::Display for LanguageModelToolUseId {
488    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
489        write!(f, "{}", self.0)
490    }
491}
492
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    /// Builds a tool-use id from anything convertible into `Arc<str>`
    /// (e.g. `&str` or `String`).
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
501
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Provider-assigned id for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Raw tool-input text as received from the model.
    pub raw_input: String,
    /// Tool input as a JSON value.
    pub input: serde_json::Value,
    /// Whether `input` represents the complete tool input (streaming may
    /// deliver partial input first).
    pub is_input_complete: bool,
}
510
/// A text-only view over a completion stream, as produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    /// Message id, when the provider emitted a `StartMessage` event first.
    pub message_id: Option<String>,
    /// Stream of text chunks; non-text events are filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
517
518impl Default for LanguageModelTextStream {
519    fn default() -> Self {
520        Self {
521            message_id: None,
522            stream: Box::pin(futures::stream::empty()),
523            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
524        }
525    }
526}
527
/// Interface implemented by every language model, regardless of provider.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Id of the provider actually serving requests; defaults to
    /// `provider_id`. Implementations may override when requests are proxied
    /// through another provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name counterpart of `upstream_provider_id`.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Identifier reported in telemetry for this model.
    fn telemetry_id(&self) -> String;

    /// API key used for this model, if any; defaults to `None`.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode";
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// Schema dialect expected for tool input; defaults to full JSON Schema.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Size of this model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    /// Maximum number of output tokens, if the model defines one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Starts a completion and returns the stream of completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper around `stream_completion` that reduces the event
    /// stream to plain text chunks, capturing the message id (from a leading
    /// `StartMessage`) and the final token usage as side channels.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event to capture a message id or the first
            // text chunk before handing the rest of the stream off.
            // NOTE(review): any other first event — including an `Err` — is
            // silently discarded here; confirm that is intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend the peeked text (if any), then keep only text chunks
            // and errors; usage updates are recorded into `last_token_usage`.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    // Keep only the latest usage report.
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Prompt-caching configuration, if the model supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcasts to the fake test model; panics for real implementations.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
663
664pub trait LanguageModelExt: LanguageModel {
665    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
666        match mode {
667            CompletionMode::Normal => self.max_token_count(),
668            CompletionMode::Max => self
669                .max_token_count_in_burn_mode()
670                .unwrap_or_else(|| self.max_token_count()),
671        }
672    }
673}
674impl LanguageModelExt for dyn LanguageModel {}
675
/// A tool that can be exposed to a language model; the input schema is
/// derived from the implementing type via `JsonSchema`, and the model's input
/// is deserialized into the type via `DeserializeOwned`.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}
680
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
691
/// A source of language models (e.g. Anthropic, OpenAI, Zed cloud).
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one can be resolved.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default fast model, if one can be resolved.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models this provider recommends surfacing; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration UI for the given target agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
714
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// A named external agent.
    Other(SharedString),
}
721
/// Where a provider's terms-of-service notice is rendered.
// NOTE(review): the comments on the first two variants appear swapped
// relative to the variant names ("EmptyState" = some interactions,
// "FreshStart" = none) — confirm with the UI call sites.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
731
732pub trait LanguageModelProviderState: 'static {
733    type ObservableEntity;
734
735    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
736
737    fn subscribe<T: 'static>(
738        &self,
739        cx: &mut gpui::Context<T>,
740        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
741    ) -> Option<gpui::Subscription> {
742        let entity = self.observable_entity()?;
743        Some(cx.observe(&entity, move |this, _, cx| {
744            callback(this, cx);
745        }))
746    }
747}
748
/// Stable identifier for a specific language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable identifier for a provider (e.g. "anthropic", "zed.dev").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name (e.g. "Anthropic", "Zed").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
760
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; `const` so it can back
    /// `pub const` declarations like `ANTHROPIC_PROVIDER_ID`.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
766
impl LanguageModelProviderName {
    /// Creates a provider name from a static string; `const` so it can back
    /// `pub const` declarations like `ANTHROPIC_PROVIDER_NAME`.
    // NOTE(review): the parameter holds a display name despite being
    // called `id`, mirroring `LanguageModelProviderId::new`.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
772
773impl fmt::Display for LanguageModelProviderId {
774    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
775        write!(f, "{}", self.0)
776    }
777}
778
779impl fmt::Display for LanguageModelProviderName {
780    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
781        write!(f, "{}", self.0)
782    }
783}
784
785impl From<String> for LanguageModelId {
786    fn from(value: String) -> Self {
787        Self(SharedString::from(value))
788    }
789}
790
791impl From<String> for LanguageModelName {
792    fn from(value: String) -> Self {
793        Self(SharedString::from(value))
794    }
795}
796
797impl From<String> for LanguageModelProviderId {
798    fn from(value: String) -> Self {
799        Self(SharedString::from(value))
800    }
801}
802
803impl From<String> for LanguageModelProviderName {
804    fn from(value: String) -> Self {
805        Self(SharedString::from(value))
806    }
807}
808
809impl From<Arc<str>> for LanguageModelProviderId {
810    fn from(value: Arc<str>) -> Self {
811        Self(SharedString::from(value))
812    }
813}
814
815impl From<Arc<str>> for LanguageModelProviderName {
816    fn from(value: Arc<str>) -> Self {
817        Self(SharedString::from(value))
818    }
819}
820
#[cfg(test)]
mod tests {
    use super::*;

    // `upstream_http_error` payloads embed the real upstream status in a JSON
    // body; verify it is extracted and routed through `from_http_status`
    // (503 -> ServerOverloaded, 500 -> ApiInternalServerError).
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // The `upstream_http_<status>` code format carries the status in the code
    // string itself rather than in a JSON body.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // Connection-timeout messages must still be classified by the embedded
    // upstream status, and the inner JSON `message` field must be preserved.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}