language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

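    /// Maps an error reported by the cloud proxy onto a specific variant.
    ///
    /// A prompt-too-long message short-circuits to `PromptTooLarge`. Otherwise
    /// three `code` formats are handled: `upstream_http_error` (with a JSON
    /// payload carrying `upstream_status` and `message`), `upstream_http_NNN`
    /// (attributed to the upstream provider), and `http_NNN` (attributed to
    /// Zed's own provider). Anything else falls through to `Other`.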
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

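    /// Maps an HTTP status code from a provider's API onto a specific variant.
    ///
    /// Note that 529, a non-standard status used by some providers (notably
    /// Anthropic) to signal an overloaded API, is treated like 503 and maps
    /// to `ServerOverloaded`.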
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

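// Token accounting for a completion. The two cache fields mirror
// Anthropic-style prompt-caching counters; providers that don't report cache
// usage simply leave them at their serde default of zero.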
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

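// Subtraction is component-wise and unchecked: each field uses plain `u64`
// subtraction, which panics on underflow in debug builds and wraps in release
// builds, so callers must ensure `other` never exceeds `self` in any field.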
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
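    // For models proxied through another service (e.g. Zed's cloud provider),
    // these identify the provider that ultimately serves the request; by
    // default they are the same as `provider_id` and `provider_name`.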
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (if
    /// `supports_burn_mode` returns `false`, this returns `None`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

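    /// Convenience wrapper around `stream_completion` that reduces the event
    /// stream to plain text: a leading `StartMessage` id (if any) is captured,
    /// `Text` events pass through, `UsageUpdate`s are recorded into
    /// `last_token_usage`, and all other events are dropped.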
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

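/// A typed tool definition exposed to a language model: the implementing type
/// doubles as the tool's input schema (via `JsonSchema`) and as the
/// deserialization target for the model's tool-call arguments.
///
/// A minimal sketch of an implementation; the `FindReferences` tool below is
/// hypothetical, not part of this crate:
///
/// ```ignore
/// #[derive(serde::Deserialize, schemars::JsonSchema)]
/// struct FindReferences {
///     /// The symbol to look up.
///     symbol: String,
/// }
///
/// impl LanguageModelTool for FindReferences {
///     fn name() -> String {
///         "find_references".into()
///     }
///
///     fn description() -> String {
///         "Find all references to a symbol in the project.".into()
///     }
/// }
/// ```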
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
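
    // Illustrative additions: sanity checks for `TokenUsage` arithmetic and
    // for the non-standard 529 status mapping, both following directly from
    // the definitions above.
    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 10,
            cache_read_input_tokens: 5,
        };
        let b = TokenUsage {
            input_tokens: 40,
            output_tokens: 20,
            cache_creation_input_tokens: 10,
            cache_read_input_tokens: 5,
        };

        assert_eq!(a.total_tokens(), 165);
        assert_eq!((a + b).total_tokens(), 240);
        // Subtraction is component-wise; every field of `b` must be <= the
        // corresponding field of `a`, or the `u64` subtraction underflows.
        assert_eq!((a - b).total_tokens(), 90);
    }

    #[test]
    fn test_from_http_status_529_maps_to_server_overloaded() {
        let status = StatusCode::from_u16(529).unwrap();
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            status,
            "overloaded".to_string(),
            None,
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::ServerOverloaded { .. }
        ));
    }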
}