mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

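/// An error that occurred while streaming a completion from a language model provider.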
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
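    /// Attempts to extract an upstream HTTP status code and message from a JSON
    /// error payload of the form `{"upstream_status": 503, "message": "..."}`.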
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

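    /// Builds a completion error from a failure code reported by Zed's cloud
    /// backend, recovering the upstream provider's HTTP status where possible.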
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

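    /// Maps an HTTP status code returned by a provider's API to the
    /// corresponding completion error variant.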
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
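            // 529 is a non-standard status code that Anthropic uses to signal
            // that its servers are overloaded.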
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

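/// Token counts reported by a provider for a single completion, including
/// tokens written to and read from the prompt cache.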
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
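    /// Returns the sum of all four token counters.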
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

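/// A stream of plain-text completion chunks, along with the message ID and the
/// final token usage once the stream completes.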
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

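/// A language model from a single provider, able to count tokens for a request
/// and stream completion events back to the caller.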
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

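    /// Streams a completion as plain text, skipping tool-use, thinking, and
    /// status events while recording the most recent reported token usage.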
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
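    /// Returns the context window size for the given completion mode, falling
    /// back to the normal limit when burn mode is unsupported.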
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}

impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

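/// A provider of language models (for example Anthropic, OpenAI, or Zed's
/// cloud service), responsible for authentication, configuration, and model
/// discovery.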
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

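/// A provider whose internal state can be observed for changes through a gpui
/// entity.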
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(name: &'static str) -> Self {
        Self(SharedString::new_static(name))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
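
    // A minimal sanity check of `from_http_status`'s direct mappings; these
    // assertions follow straight from its match arms.
    #[test]
    fn test_from_http_status_mappings() {
        let provider = LanguageModelProviderName::new("anthropic");

        let error = LanguageModelCompletionError::from_http_status(
            provider.clone(),
            StatusCode::TOO_MANY_REQUESTS,
            "rate limited".to_string(),
            Some(Duration::from_secs(30)),
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::RateLimitExceeded { .. }
        ));

        // 529 is the non-standard "overloaded" status handled explicitly.
        let error = LanguageModelCompletionError::from_http_status(
            provider,
            StatusCode::from_u16(529).unwrap(),
            "overloaded".to_string(),
            None,
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::ServerOverloaded { .. }
        ));
    }

    // A small sketch of `TokenUsage` arithmetic: `total_tokens` sums all four
    // counters, and `Add` combines two usages field by field.
    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 3,
        };
        let b = TokenUsage {
            input_tokens: 1,
            output_tokens: 1,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 1,
        };
        assert_eq!(a.total_tokens(), 20);
        assert_eq!((a + b).total_tokens(), 24);
    }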
}