mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::fmt::Debug;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

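/// An error that can occur while streaming a completion from a language model
/// provider, distinguishing provider-side failures (rate limits, overload,
/// internal errors) from client-side ones (bad requests, missing credentials,
/// serialization failures).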
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
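    /// Parses the JSON payload of an `upstream_http_error`, expecting an
    /// `upstream_status` HTTP status code and, when present, an inner
    /// `message` (falling back to the raw payload otherwise).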
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

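    /// Converts a failure reported by Zed's cloud service into a completion
    /// error. `upstream_http_error` and `upstream_http_<status>` codes are
    /// attributed to the upstream provider; bare `http_<status>` codes are
    /// attributed to Zed itself; anything unrecognized becomes `Other`.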
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

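    /// Maps an HTTP status code from a provider's API to the corresponding
    /// error variant. The non-standard 529 status, which some providers
    /// (notably Anthropic) use to signal overload, is treated like
    /// `SERVICE_UNAVAILABLE`.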
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

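/// Cumulative token counts reported by a model, including cache creation and
/// cache read tokens for providers that support prompt caching.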
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Contains the complete token usage after the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Debug for LanguageModelTextStream {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LanguageModelTextStream")
            .field("message_id", &self.message_id)
            .finish()
    }
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

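/// An interface to a language model from a particular provider, exposing the
/// model's capabilities (images, tools, burn mode), its token limits, and
/// streaming completion methods.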
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

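    /// Streams a completion as plain text. This convenience wrapper around
    /// [`Self::stream_completion`] records the message ID from the first
    /// event when present, drops all non-text events, and stores the final
    /// token usage in [`LanguageModelTextStream::last_token_usage`].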
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

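/// Extension methods for `dyn LanguageModel` trait objects.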
pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

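/// A source of language models (e.g. Anthropic, OpenAI, or Zed's cloud
/// service), responsible for listing its models, authentication, and its
/// configuration UI.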
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }
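
    #[test]
    fn test_from_cloud_failure_with_cloud_http_code() {
        // Illustrative sketch of the remaining `from_cloud_failure` branch:
        // a bare `http_<status>` code (no `upstream_` prefix) is attributed
        // to Zed's own cloud service rather than the upstream provider.
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "http_429".to_string(),
            "Too many requests".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::RateLimitExceeded {
                provider,
                retry_after,
            } => {
                assert_eq!(provider, ZED_CLOUD_PROVIDER_NAME);
                assert_eq!(retry_after, None);
            }
            _ => panic!("Expected RateLimitExceeded for http_429, got: {:?}", error),
        }
    }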

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
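
    #[test]
    fn test_token_usage_arithmetic() {
        // Illustrative check of the TokenUsage bookkeeping above:
        // `total_tokens` sums all four counters, and the `Add`/`Sub` impls
        // combine them field-wise.
        let a = TokenUsage {
            input_tokens: 100,
            output_tokens: 20,
            cache_creation_input_tokens: 5,
            cache_read_input_tokens: 3,
        };
        let b = TokenUsage {
            input_tokens: 50,
            output_tokens: 10,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 2,
        };

        assert_eq!(a.total_tokens(), 128);
        assert_eq!((a + b).total_tokens(), 191);
        assert_eq!((a + b) - b, a);
    }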
}