language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7pub mod tool_schema;
  8
  9#[cfg(any(test, feature = "test-support"))]
 10pub mod fake_provider;
 11
 12use anthropic::{AnthropicError, parse_prompt_too_long};
 13use anyhow::{Result, anyhow};
 14use client::Client;
 15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
 16use futures::FutureExt;
 17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 19use http_client::{StatusCode, http};
 20use icons::IconName;
 21use open_router::OpenRouterError;
 22use parking_lot::Mutex;
 23use serde::{Deserialize, Serialize};
 24pub use settings::LanguageModelCacheConfiguration;
 25use std::ops::{Add, Sub};
 26use std::str::FromStr;
 27use std::sync::Arc;
 28use std::time::Duration;
 29use std::{fmt, io};
 30use thiserror::Error;
 31use util::serde::is_default;
 32
 33pub use crate::model::*;
 34pub use crate::rate_limiter::*;
 35pub use crate::registry::*;
 36pub use crate::request::*;
 37pub use crate::role::*;
 38pub use crate::telemetry::*;
 39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 40
// Well-known provider id/name pairs. The id is the stable machine key used in
// settings and telemetry; the name is the user-facing label.

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

// Zed's own cloud provider, which proxies to upstream providers.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
 60
/// Initializes this crate: sets up the model registry (via [`init_settings`])
/// and registers the LLM token refresh listener with the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
 65
/// Initializes just the registry, without any client-dependent setup.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
 69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Status update for the in-flight completion request.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's reasoning ("thinking") output.
    Thinking {
        text: String,
        /// Signature some providers attach to reasoning content.
        signature: Option<String>,
    },
    /// Reasoning content the provider redacted; only opaque data remains.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Start of a new assistant message, carrying the provider's message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Token usage reported by the provider; the last update reflects the
    /// usage for the whole request.
    UsageUpdate(TokenUsage),
}
 96
/// Errors that can occur while requesting or streaming a completion,
/// normalized across providers.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error reported by an upstream provider; the message is surfaced verbatim.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses with no dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },

    // Local failures while constructing/sending the request or decoding the response.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
184
impl LanguageModelCompletionError {
    /// Parses a cloud `upstream_http_error` payload of the form
    /// `{"upstream_status": <u16>, "message": "..."}` into the upstream
    /// status code and inner message. Returns `None` if the payload isn't
    /// JSON or lacks a valid `upstream_status`.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        // Fall back to the raw payload when there is no "message" field.
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Converts a failure reported by the Zed cloud service into a structured
    /// error. Recognized `code` shapes, in order of precedence:
    /// a prompt-too-long message, `upstream_http_error` (JSON payload),
    /// `upstream_http_<status>` (attributed to `upstream_provider`), and
    /// `http_<status>` (attributed to Zed itself). Anything else becomes
    /// `Other`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Unparseable payload: preserve everything in an opaque error.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            // Plain `http_<status>` means the Zed server itself failed.
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from a provider's API onto the matching error
    /// variant, falling back to `HttpResponseError` for anything unrecognized.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard status some providers use for overload.
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
276
277impl From<AnthropicError> for LanguageModelCompletionError {
278    fn from(error: AnthropicError) -> Self {
279        let provider = ANTHROPIC_PROVIDER_NAME;
280        match error {
281            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
282            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
283            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
284            AnthropicError::DeserializeResponse(error) => {
285                Self::DeserializeResponse { provider, error }
286            }
287            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
288            AnthropicError::HttpResponseError {
289                status_code,
290                message,
291            } => Self::HttpResponseError {
292                provider,
293                status_code,
294                message,
295            },
296            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
297                provider,
298                retry_after: Some(retry_after),
299            },
300            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
301                provider,
302                retry_after,
303            },
304            AnthropicError::ApiError(api_error) => api_error.into(),
305        }
306    }
307}
308
309impl From<anthropic::ApiError> for LanguageModelCompletionError {
310    fn from(error: anthropic::ApiError) -> Self {
311        use anthropic::ApiErrorCode::*;
312        let provider = ANTHROPIC_PROVIDER_NAME;
313        match error.code() {
314            Some(code) => match code {
315                InvalidRequestError => Self::BadRequestFormat {
316                    provider,
317                    message: error.message,
318                },
319                AuthenticationError => Self::AuthenticationError {
320                    provider,
321                    message: error.message,
322                },
323                PermissionError => Self::PermissionError {
324                    provider,
325                    message: error.message,
326                },
327                NotFoundError => Self::ApiEndpointNotFound { provider },
328                RequestTooLarge => Self::PromptTooLarge {
329                    tokens: parse_prompt_too_long(&error.message),
330                },
331                RateLimitError => Self::RateLimitExceeded {
332                    provider,
333                    retry_after: None,
334                },
335                ApiError => Self::ApiInternalServerError {
336                    provider,
337                    message: error.message,
338                },
339                OverloadedError => Self::ServerOverloaded {
340                    provider,
341                    retry_after: None,
342                },
343            },
344            None => Self::Other(error.into()),
345        }
346    }
347}
348
349impl From<OpenRouterError> for LanguageModelCompletionError {
350    fn from(error: OpenRouterError) -> Self {
351        let provider = LanguageModelProviderName::new("OpenRouter");
352        match error {
353            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
354            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
355            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
356            OpenRouterError::DeserializeResponse(error) => {
357                Self::DeserializeResponse { provider, error }
358            }
359            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
360            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
361                provider,
362                retry_after: Some(retry_after),
363            },
364            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
365                provider,
366                retry_after,
367            },
368            OpenRouterError::ApiError(api_error) => api_error.into(),
369        }
370    }
371}
372
373impl From<open_router::ApiError> for LanguageModelCompletionError {
374    fn from(error: open_router::ApiError) -> Self {
375        use open_router::ApiErrorCode::*;
376        let provider = LanguageModelProviderName::new("OpenRouter");
377        match error.code {
378            InvalidRequestError => Self::BadRequestFormat {
379                provider,
380                message: error.message,
381            },
382            AuthenticationError => Self::AuthenticationError {
383                provider,
384                message: error.message,
385            },
386            PaymentRequiredError => Self::AuthenticationError {
387                provider,
388                message: format!("Payment required: {}", error.message),
389            },
390            PermissionError => Self::PermissionError {
391                provider,
392                message: error.message,
393            },
394            RequestTimedOut => Self::HttpResponseError {
395                provider,
396                status_code: StatusCode::REQUEST_TIMEOUT,
397                message: error.message,
398            },
399            RateLimitError => Self::RateLimitExceeded {
400                provider,
401                retry_after: None,
402            },
403            ApiError => Self::ApiInternalServerError {
404                provider,
405                message: error.message,
406            },
407            OverloadedError => Self::ServerOverloaded {
408                provider,
409                retry_after: None,
410            },
411        }
412    }
413}
414
/// The reason a language model stopped generating output.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Output was cut off after hitting the token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
423
/// Token accounting for a completion request, as reported by the provider.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens, excluding cache reads/writes.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Generated output tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens spent writing prompt-cache entries.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
435
436impl TokenUsage {
437    pub fn total_tokens(&self) -> u64 {
438        self.input_tokens
439            + self.output_tokens
440            + self.cache_read_input_tokens
441            + self.cache_creation_input_tokens
442    }
443}
444
445impl Add<TokenUsage> for TokenUsage {
446    type Output = Self;
447
448    fn add(self, other: Self) -> Self {
449        Self {
450            input_tokens: self.input_tokens + other.input_tokens,
451            output_tokens: self.output_tokens + other.output_tokens,
452            cache_creation_input_tokens: self.cache_creation_input_tokens
453                + other.cache_creation_input_tokens,
454            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
455        }
456    }
457}
458
459impl Sub<TokenUsage> for TokenUsage {
460    type Output = Self;
461
462    fn sub(self, other: Self) -> Self {
463        Self {
464            input_tokens: self.input_tokens - other.input_tokens,
465            output_tokens: self.output_tokens - other.output_tokens,
466            cache_creation_input_tokens: self.cache_creation_input_tokens
467                - other.cache_creation_input_tokens,
468            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
469        }
470    }
471}
472
/// Identifier for a single tool invocation issued by a model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
475
476impl fmt::Display for LanguageModelToolUseId {
477    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478        write!(f, "{}", self.0)
479    }
480}
481
482impl<T> From<T> for LanguageModelToolUseId
483where
484    T: Into<Arc<str>>,
485{
486    fn from(value: T) -> Self {
487        Self(value.into())
488    }
489}
490
/// A tool invocation requested by a model during a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The tool input exactly as streamed from the provider.
    pub raw_input: String,
    /// The parsed JSON form of `raw_input`.
    pub input: serde_json::Value,
    /// Whether the input has finished streaming.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
502
/// A text-only view of a completion stream (see
/// `LanguageModel::stream_completion_text`).
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when one was reported.
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
509
510impl Default for LanguageModelTextStream {
511    fn default() -> Self {
512        Self {
513            message_id: None,
514            stream: Box::pin(futures::stream::empty()),
515            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
516        }
517    }
518}
519
/// A language model that can count tokens and stream completions.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// For proxied models, the provider that actually serves the model.
    /// Defaults to this model's own provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    /// The API key to use for this model, if any. Defaults to `None`.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode";
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// How tool schemas should be formatted for this model; defaults to JSON Schema.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Convenience wrapper over [`LanguageModel::stream_completion`] that
    /// reduces the event stream to a plain text stream, recording token
    /// usage into `last_token_usage` as a side channel.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message id
            // (`StartMessage`) or already be the first text chunk, which is
            // stashed so it isn't lost when the remaining events are chained
            // on below.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            // Forward only `Text` items and errors; record
                            // usage updates; drop everything else.
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    // Overwrite rather than accumulate: the
                                    // last update holds the final usage.
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Prompt-caching configuration, if the model supports caching.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
656
657pub trait LanguageModelExt: LanguageModel {
658    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
659        match mode {
660            CompletionMode::Normal => self.max_token_count(),
661            CompletionMode::Max => self
662                .max_token_count_in_burn_mode()
663                .unwrap_or_else(|| self.max_token_count()),
664        }
665    }
666}
667impl LanguageModelExt for dyn LanguageModel {}
668
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint refused the connection.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
679
/// A language model provider: its identity, the models it offers, and how to
/// authenticate and configure it.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default model for fast/lightweight tasks, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models to surface prominently; defaults to none.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration view (e.g. API-key entry) for this provider,
    /// targeted at the given agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
702
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// Some other agent, identified by name.
    Other(SharedString),
}
709
/// The UI surface in which a provider's terms-of-service view is shown.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Shown in the text-thread popup.
    TextThreadPopup,
    /// Shown in the provider configuration UI.
    Configuration,
}
719
720pub trait LanguageModelProviderState: 'static {
721    type ObservableEntity;
722
723    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
724
725    fn subscribe<T: 'static>(
726        &self,
727        cx: &mut gpui::Context<T>,
728        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
729    ) -> Option<gpui::Subscription> {
730        let entity = self.observable_entity()?;
731        Some(cx.observe(&entity, move |this, _, cx| {
732            callback(this, cx);
733        }))
734    }
735}
736
/// Identifier for a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable identifier for a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable name of a provider (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
748
impl LanguageModelProviderId {
    /// Creates a provider id from a static string; usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Creates a provider name from a static string; usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
760
761impl fmt::Display for LanguageModelProviderId {
762    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
763        write!(f, "{}", self.0)
764    }
765}
766
767impl fmt::Display for LanguageModelProviderName {
768    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
769        write!(f, "{}", self.0)
770    }
771}
772
773impl From<String> for LanguageModelId {
774    fn from(value: String) -> Self {
775        Self(SharedString::from(value))
776    }
777}
778
779impl From<String> for LanguageModelName {
780    fn from(value: String) -> Self {
781        Self(SharedString::from(value))
782    }
783}
784
785impl From<String> for LanguageModelProviderId {
786    fn from(value: String) -> Self {
787        Self(SharedString::from(value))
788    }
789}
790
791impl From<String> for LanguageModelProviderName {
792    fn from(value: String) -> Self {
793        Self(SharedString::from(value))
794    }
795}
796
797impl From<Arc<str>> for LanguageModelProviderId {
798    fn from(value: Arc<str>) -> Self {
799        Self(SharedString::from(value))
800    }
801}
802
803impl From<Arc<str>> for LanguageModelProviderName {
804    fn from(value: Arc<str>) -> Self {
805        Self(SharedString::from(value))
806    }
807}
808
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a cloud-failure completion error for the "anthropic" provider,
    /// so each test only has to spell out the error code and message payload.
    fn anthropic_cloud_failure(code: &str, message: &str) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            message.to_string(),
            None,
        )
    }

    /// Builds a fully-populated `LanguageModelToolUse` whose `raw_input` is the
    /// serialized form of `input`, with an optional thought signature.
    fn tool_use_fixture(
        id: &str,
        name: &str,
        input: serde_json::Value,
        thought_signature: Option<&str>,
    ) -> LanguageModelToolUse {
        LanguageModelToolUse {
            id: LanguageModelToolUseId::from(id),
            name: name.into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: thought_signature.map(str::to_string),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        // An `upstream_http_error` whose JSON payload carries a 503 upstream
        // status should surface as `ServerOverloaded`.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        // A 500 upstream status in the same payload shape maps to
        // `ApiInternalServerError`, preserving the upstream message verbatim.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        // The status can also arrive encoded in the error code itself
        // (`upstream_http_503`) rather than in a JSON payload.
        let error = anthropic_cloud_failure("upstream_http_503", "Service unavailable");

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for upstream_http_503, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        // Connection-timeout text with a 503 upstream status is still treated
        // as an overloaded server.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        // The same timeout text with a 500 upstream status maps to an internal
        // server error, keeping the full upstream message.
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        // A present `thought_signature` must appear in the serialized output.
        let tool_use = tool_use_fixture(
            "test_id",
            "test_tool",
            json!({"arg": "value"}),
            Some("test_signature"),
        );

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        // Older payloads without `thought_signature` must still deserialize,
        // defaulting the field to `None`.
        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        // Serialize-then-deserialize must preserve a present signature.
        let original = tool_use_fixture(
            "round_trip_id",
            "round_trip_tool",
            json!({"key": "value"}),
            Some("round_trip_sig"),
        );

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        // Serialize-then-deserialize must preserve an absent signature.
        let original = tool_use_fixture("no_sig_id", "no_sig_tool", json!({"arg": "value"}), None);

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}