language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7pub mod tool_schema;
  8
  9#[cfg(any(test, feature = "test-support"))]
 10pub mod fake_provider;
 11
 12use anthropic::{AnthropicError, parse_prompt_too_long};
 13use anyhow::{Result, anyhow};
 14use client::Client;
 15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
 16use futures::FutureExt;
 17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 19use http_client::{StatusCode, http};
 20use icons::IconName;
 21use open_router::OpenRouterError;
 22use parking_lot::Mutex;
 23use serde::{Deserialize, Serialize};
 24pub use settings::LanguageModelCacheConfiguration;
 25use std::ops::{Add, Sub};
 26use std::str::FromStr;
 27use std::sync::Arc;
 28use std::time::Duration;
 29use std::{fmt, io};
 30use thiserror::Error;
 31use util::serde::is_default;
 32
 33pub use crate::model::*;
 34pub use crate::rate_limiter::*;
 35pub use crate::registry::*;
 36pub use crate::request::*;
 37pub use crate::role::*;
 38pub use crate::telemetry::*;
 39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 40
 41pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
 42    LanguageModelProviderId::new("anthropic");
 43pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
 44    LanguageModelProviderName::new("Anthropic");
 45
 46pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
 47pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
 48    LanguageModelProviderName::new("Google AI");
 49
 50pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
 51pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
 52    LanguageModelProviderName::new("OpenAI");
 53
 54pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
 55pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");
 56
 57pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
 58pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
 59    LanguageModelProviderName::new("Zed");
 60
/// Initializes this crate.
///
/// Runs [`init_settings`] first and then registers a `RefreshLlmTokenListener`
/// for `client`.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

/// Initializes only the model registry (`registry::init`).
///
/// Unlike [`init`], this needs no [`Client`].
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
 69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A status update about the completion request itself.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// Generated text.
    ///
    /// `LanguageModel::stream_completion_text` forwards only these (and errors) to
    /// its text stream.
    Text(String),
    /// Reasoning ("thinking") output from the model.
    Thinking {
        text: String,
        /// Optional signature accompanying the thinking text.
        /// NOTE(review): presumably must be echoed back to the provider; confirm.
        signature: Option<String>,
    },
    /// Thinking content returned without readable text.
    /// NOTE(review): `data` is presumably an opaque provider payload.
    RedactedThinking {
        data: String,
    },
    /// A tool call requested by the model (its input may still be streaming; see
    /// `LanguageModelToolUse::is_input_complete`).
    ToolUse(LanguageModelToolUse),
    /// A tool call whose input could not be parsed as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        /// The unparsed input text.
        raw_input: Arc<str>,
        /// Description of the JSON parse failure.
        json_parse_error: String,
    },
    /// Start of a message, carrying its id.
    ///
    /// `LanguageModel::stream_completion_text` reads this from the first event.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage.
    ///
    /// `LanguageModel::stream_completion_text` stores the most recent value.
    UsageUpdate(TokenUsage),
}
 95
/// Errors produced while requesting or streaming a completion from a language model
/// provider.
///
/// Most variants record the provider the error is attributed to.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window.
    ///
    /// `tokens` is the count extracted by `parse_prompt_too_long`, when available.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is configured for `provider`.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// The provider rate-limited the request (e.g. HTTP 429).
    ///
    /// `retry_after` is the suggested delay before retrying, when known.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported being overloaded (e.g. HTTP 503 or 529).
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported an internal server error (e.g. HTTP 500).
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error relayed from an upstream provider together with its HTTP status.
    /// NOTE(review): not constructed in this file; confirm where it is produced.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// An HTTP error status that has no more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// The provider rejected the request as malformed (e.g. HTTP 400).
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// Authentication with the provider failed (e.g. HTTP 401).
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The provider denied permission (e.g. HTTP 403).
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The requested API endpoint does not exist (e.g. HTTP 404).
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// Reading the response body failed with an I/O error.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// Serializing the request to JSON failed.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// Building the HTTP request failed.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// Deserializing the response JSON failed.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    /// Any other error, displayed transparently.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
183
184impl LanguageModelCompletionError {
185    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
186        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
187        let upstream_status = error_json
188            .get("upstream_status")
189            .and_then(|v| v.as_u64())
190            .and_then(|status| u16::try_from(status).ok())
191            .and_then(|status| StatusCode::from_u16(status).ok())?;
192        let inner_message = error_json
193            .get("message")
194            .and_then(|v| v.as_str())
195            .unwrap_or(message)
196            .to_string();
197        Some((upstream_status, inner_message))
198    }
199
200    pub fn from_cloud_failure(
201        upstream_provider: LanguageModelProviderName,
202        code: String,
203        message: String,
204        retry_after: Option<Duration>,
205    ) -> Self {
206        if let Some(tokens) = parse_prompt_too_long(&message) {
207            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
208            // to be reported. This is a temporary workaround to handle this in the case where the
209            // token limit has been exceeded.
210            Self::PromptTooLarge {
211                tokens: Some(tokens),
212            }
213        } else if code == "upstream_http_error" {
214            if let Some((upstream_status, inner_message)) =
215                Self::parse_upstream_error_json(&message)
216            {
217                return Self::from_http_status(
218                    upstream_provider,
219                    upstream_status,
220                    inner_message,
221                    retry_after,
222                );
223            }
224            anyhow!("completion request failed, code: {code}, message: {message}").into()
225        } else if let Some(status_code) = code
226            .strip_prefix("upstream_http_")
227            .and_then(|code| StatusCode::from_str(code).ok())
228        {
229            Self::from_http_status(upstream_provider, status_code, message, retry_after)
230        } else if let Some(status_code) = code
231            .strip_prefix("http_")
232            .and_then(|code| StatusCode::from_str(code).ok())
233        {
234            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
235        } else {
236            anyhow!("completion request failed, code: {code}, message: {message}").into()
237        }
238    }
239
240    pub fn from_http_status(
241        provider: LanguageModelProviderName,
242        status_code: StatusCode,
243        message: String,
244        retry_after: Option<Duration>,
245    ) -> Self {
246        match status_code {
247            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
248            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
249            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
250            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
251            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
252                tokens: parse_prompt_too_long(&message),
253            },
254            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
255                provider,
256                retry_after,
257            },
258            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
259            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
260                provider,
261                retry_after,
262            },
263            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
264                provider,
265                retry_after,
266            },
267            _ => Self::HttpResponseError {
268                provider,
269                status_code,
270                message,
271            },
272        }
273    }
274}
275
276impl From<AnthropicError> for LanguageModelCompletionError {
277    fn from(error: AnthropicError) -> Self {
278        let provider = ANTHROPIC_PROVIDER_NAME;
279        match error {
280            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
281            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
282            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
283            AnthropicError::DeserializeResponse(error) => {
284                Self::DeserializeResponse { provider, error }
285            }
286            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
287            AnthropicError::HttpResponseError {
288                status_code,
289                message,
290            } => Self::HttpResponseError {
291                provider,
292                status_code,
293                message,
294            },
295            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
296                provider,
297                retry_after: Some(retry_after),
298            },
299            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
300                provider,
301                retry_after,
302            },
303            AnthropicError::ApiError(api_error) => api_error.into(),
304        }
305    }
306}
307
308impl From<anthropic::ApiError> for LanguageModelCompletionError {
309    fn from(error: anthropic::ApiError) -> Self {
310        use anthropic::ApiErrorCode::*;
311        let provider = ANTHROPIC_PROVIDER_NAME;
312        match error.code() {
313            Some(code) => match code {
314                InvalidRequestError => Self::BadRequestFormat {
315                    provider,
316                    message: error.message,
317                },
318                AuthenticationError => Self::AuthenticationError {
319                    provider,
320                    message: error.message,
321                },
322                PermissionError => Self::PermissionError {
323                    provider,
324                    message: error.message,
325                },
326                NotFoundError => Self::ApiEndpointNotFound { provider },
327                RequestTooLarge => Self::PromptTooLarge {
328                    tokens: parse_prompt_too_long(&error.message),
329                },
330                RateLimitError => Self::RateLimitExceeded {
331                    provider,
332                    retry_after: None,
333                },
334                ApiError => Self::ApiInternalServerError {
335                    provider,
336                    message: error.message,
337                },
338                OverloadedError => Self::ServerOverloaded {
339                    provider,
340                    retry_after: None,
341                },
342            },
343            None => Self::Other(error.into()),
344        }
345    }
346}
347
348impl From<OpenRouterError> for LanguageModelCompletionError {
349    fn from(error: OpenRouterError) -> Self {
350        let provider = LanguageModelProviderName::new("OpenRouter");
351        match error {
352            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
353            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
354            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
355            OpenRouterError::DeserializeResponse(error) => {
356                Self::DeserializeResponse { provider, error }
357            }
358            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
359            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
360                provider,
361                retry_after: Some(retry_after),
362            },
363            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
364                provider,
365                retry_after,
366            },
367            OpenRouterError::ApiError(api_error) => api_error.into(),
368        }
369    }
370}
371
372impl From<open_router::ApiError> for LanguageModelCompletionError {
373    fn from(error: open_router::ApiError) -> Self {
374        use open_router::ApiErrorCode::*;
375        let provider = LanguageModelProviderName::new("OpenRouter");
376        match error.code {
377            InvalidRequestError => Self::BadRequestFormat {
378                provider,
379                message: error.message,
380            },
381            AuthenticationError => Self::AuthenticationError {
382                provider,
383                message: error.message,
384            },
385            PaymentRequiredError => Self::AuthenticationError {
386                provider,
387                message: format!("Payment required: {}", error.message),
388            },
389            PermissionError => Self::PermissionError {
390                provider,
391                message: error.message,
392            },
393            RequestTimedOut => Self::HttpResponseError {
394                provider,
395                status_code: StatusCode::REQUEST_TIMEOUT,
396                message: error.message,
397            },
398            RateLimitError => Self::RateLimitExceeded {
399                provider,
400                retry_after: None,
401            },
402            ApiError => Self::ApiInternalServerError {
403                provider,
404                message: error.message,
405            },
406            OverloadedError => Self::ServerOverloaded {
407                provider,
408                retry_after: None,
409            },
410        }
411    }
412}
413
/// Why the model stopped generating, as reported in
/// `LanguageModelCompletionEvent::Stop`.
///
/// Serialized in `snake_case` (e.g. `"end_turn"`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The output token limit was reached.
    MaxTokens,
    /// The model stopped in order to call a tool.
    ToolUse,
    /// The model declined to respond.
    Refusal,
}
422
/// Token counts reported for a completion.
///
/// Missing fields deserialize to 0 (`#[serde(default)]`). Fields for which
/// `is_default` holds (presumably 0) are omitted when serializing.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Generated output tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens reported as written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens reported as read from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
434
435impl TokenUsage {
436    pub fn total_tokens(&self) -> u64 {
437        self.input_tokens
438            + self.output_tokens
439            + self.cache_read_input_tokens
440            + self.cache_creation_input_tokens
441    }
442}
443
444impl Add<TokenUsage> for TokenUsage {
445    type Output = Self;
446
447    fn add(self, other: Self) -> Self {
448        Self {
449            input_tokens: self.input_tokens + other.input_tokens,
450            output_tokens: self.output_tokens + other.output_tokens,
451            cache_creation_input_tokens: self.cache_creation_input_tokens
452                + other.cache_creation_input_tokens,
453            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
454        }
455    }
456}
457
458impl Sub<TokenUsage> for TokenUsage {
459    type Output = Self;
460
461    fn sub(self, other: Self) -> Self {
462        Self {
463            input_tokens: self.input_tokens - other.input_tokens,
464            output_tokens: self.output_tokens - other.output_tokens,
465            cache_creation_input_tokens: self.cache_creation_input_tokens
466                - other.cache_creation_input_tokens,
467            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
468        }
469    }
470}
471
/// Identifier of a single tool call, stored as a shared `Arc<str>`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
474
475impl fmt::Display for LanguageModelToolUseId {
476    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477        write!(f, "{}", self.0)
478    }
479}
480
481impl<T> From<T> for LanguageModelToolUseId
482where
483    T: Into<Arc<str>>,
484{
485    fn from(value: T) -> Self {
486        Self(value.into())
487    }
488}
489
/// A tool call requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool call.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Tool input text as received from the model.
    /// NOTE(review): presumably the JSON source that `input` was parsed from.
    pub raw_input: String,
    /// Parsed tool input.
    pub input: serde_json::Value,
    /// Whether the tool input is complete.
    /// NOTE(review): presumably `false` while the input is still streaming.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
501
502pub struct LanguageModelTextStream {
503    pub message_id: Option<String>,
504    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
505    // Has complete token usage after the stream has finished
506    pub last_token_usage: Arc<Mutex<TokenUsage>>,
507}
508
509impl Default for LanguageModelTextStream {
510    fn default() -> Self {
511        Self {
512            message_id: None,
513            stream: Box::pin(futures::stream::empty()),
514            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
515        }
516    }
517}
518
519pub trait LanguageModel: Send + Sync {
520    fn id(&self) -> LanguageModelId;
521    fn name(&self) -> LanguageModelName;
522    fn provider_id(&self) -> LanguageModelProviderId;
523    fn provider_name(&self) -> LanguageModelProviderName;
524    fn upstream_provider_id(&self) -> LanguageModelProviderId {
525        self.provider_id()
526    }
527    fn upstream_provider_name(&self) -> LanguageModelProviderName {
528        self.provider_name()
529    }
530
531    fn telemetry_id(&self) -> String;
532
533    fn api_key(&self, _cx: &App) -> Option<String> {
534        None
535    }
536
537    /// Whether this model supports images
538    fn supports_images(&self) -> bool;
539
540    /// Whether this model supports tools.
541    fn supports_tools(&self) -> bool;
542
543    /// Whether this model supports choosing which tool to use.
544    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
545
546    /// Returns whether this model supports "burn mode";
547    fn supports_burn_mode(&self) -> bool {
548        false
549    }
550
551    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
552        LanguageModelToolSchemaFormat::JsonSchema
553    }
554
555    fn max_token_count(&self) -> u64;
556    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
557    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
558        None
559    }
560    fn max_output_tokens(&self) -> Option<u64> {
561        None
562    }
563
564    fn count_tokens(
565        &self,
566        request: LanguageModelRequest,
567        cx: &App,
568    ) -> BoxFuture<'static, Result<u64>>;
569
570    fn stream_completion(
571        &self,
572        request: LanguageModelRequest,
573        cx: &AsyncApp,
574    ) -> BoxFuture<
575        'static,
576        Result<
577            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
578            LanguageModelCompletionError,
579        >,
580    >;
581
582    fn stream_completion_text(
583        &self,
584        request: LanguageModelRequest,
585        cx: &AsyncApp,
586    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
587        let future = self.stream_completion(request, cx);
588
589        async move {
590            let events = future.await?;
591            let mut events = events.fuse();
592            let mut message_id = None;
593            let mut first_item_text = None;
594            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
595
596            if let Some(first_event) = events.next().await {
597                match first_event {
598                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
599                        message_id = Some(id);
600                    }
601                    Ok(LanguageModelCompletionEvent::Text(text)) => {
602                        first_item_text = Some(text);
603                    }
604                    _ => (),
605                }
606            }
607
608            let stream = futures::stream::iter(first_item_text.map(Ok))
609                .chain(events.filter_map({
610                    let last_token_usage = last_token_usage.clone();
611                    move |result| {
612                        let last_token_usage = last_token_usage.clone();
613                        async move {
614                            match result {
615                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
616                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
617                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
618                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
619                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
620                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
621                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
622                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
623                                    ..
624                                }) => None,
625                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
626                                    *last_token_usage.lock() = token_usage;
627                                    None
628                                }
629                                Err(err) => Some(Err(err)),
630                            }
631                        }
632                    }
633                }))
634                .boxed();
635
636            Ok(LanguageModelTextStream {
637                message_id,
638                stream,
639                last_token_usage,
640            })
641        }
642        .boxed()
643    }
644
645    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
646        None
647    }
648
649    #[cfg(any(test, feature = "test-support"))]
650    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
651        unimplemented!()
652    }
653}
654
655pub trait LanguageModelExt: LanguageModel {
656    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
657        match mode {
658            CompletionMode::Normal => self.max_token_count(),
659            CompletionMode::Max => self
660                .max_token_count_in_burn_mode()
661                .unwrap_or_else(|| self.max_token_count()),
662        }
663    }
664}
665impl LanguageModelExt for dyn LanguageModel {}
666
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The connection to the provider was refused.
    #[error("connection refused")]
    ConnectionRefused,
    /// No credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure; displays the wrapped error transparently.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
677
/// A source of [`LanguageModel`]s that can be authenticated and configured.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon for this provider; defaults to `IconName::ZedAssistant`.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if any.
    /// NOTE(review): presumably a smaller/cheaper model; confirm with implementations.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models this provider recommends; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate, resolving to an [`AuthenticateError`] on failure.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration view for this provider, tailored to `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Resets this provider's credentials.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
700
/// The agent a provider configuration view is shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's own agent (the default).
    #[default]
    ZedAgent,
    /// Another agent.
    /// NOTE(review): the string presumably names that agent; confirm with callers.
    Other(SharedString),
}
707
/// Identifies where a provider's terms-of-service (ToS) view is rendered.
/// NOTE(review): meaning inferred from the type name; confirm against callers.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// NOTE(review): presumably the text thread popup; confirm.
    TextThreadPopup,
    /// NOTE(review): presumably the provider configuration UI; confirm.
    Configuration,
}
717
718pub trait LanguageModelProviderState: 'static {
719    type ObservableEntity;
720
721    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
722
723    fn subscribe<T: 'static>(
724        &self,
725        cx: &mut gpui::Context<T>,
726        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
727    ) -> Option<gpui::Subscription> {
728        let entity = self.observable_entity()?;
729        Some(cx.observe(&entity, move |this, _, cx| {
730            callback(this, cx);
731        }))
732    }
733}
734
/// Identifier of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier of a language model provider (e.g. `"anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable name of a language model provider (e.g. `"Anthropic"`), used in
/// error messages.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
747impl LanguageModelProviderId {
748    pub const fn new(id: &'static str) -> Self {
749        Self(SharedString::new_static(id))
750    }
751}
752
753impl LanguageModelProviderName {
754    pub const fn new(id: &'static str) -> Self {
755        Self(SharedString::new_static(id))
756    }
757}
758
759impl fmt::Display for LanguageModelProviderId {
760    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
761        write!(f, "{}", self.0)
762    }
763}
764
765impl fmt::Display for LanguageModelProviderName {
766    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
767        write!(f, "{}", self.0)
768    }
769}
770
771impl From<String> for LanguageModelId {
772    fn from(value: String) -> Self {
773        Self(SharedString::from(value))
774    }
775}
776
777impl From<String> for LanguageModelName {
778    fn from(value: String) -> Self {
779        Self(SharedString::from(value))
780    }
781}
782
783impl From<String> for LanguageModelProviderId {
784    fn from(value: String) -> Self {
785        Self(SharedString::from(value))
786    }
787}
788
789impl From<String> for LanguageModelProviderName {
790    fn from(value: String) -> Self {
791        Self(SharedString::from(value))
792    }
793}
794
795impl From<Arc<str>> for LanguageModelProviderId {
796    fn from(value: Arc<str>) -> Self {
797        Self(SharedString::from(value))
798    }
799}
800
801impl From<Arc<str>> for LanguageModelProviderName {
802    fn from(value: Arc<str>) -> Self {
803        Self(SharedString::from(value))
804    }
805}
806
#[cfg(test)]
mod tests {
    // Unit tests for how cloud failures are classified into
    // `LanguageModelCompletionError` variants, and for the serde
    // representation of `LanguageModelToolUse`.
    use super::*;
    use serde_json::json;

    /// Calls `LanguageModelCompletionError::from_cloud_failure` for the
    /// "anthropic" provider with the given error `code` and response `body`,
    /// passing `None` as the final argument.
    fn anthropic_cloud_failure(code: &str, body: &str) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            body.to_string(),
            None,
        )
    }

    /// An `upstream_http_error` is classified by the `upstream_status` carried
    /// in its JSON body: 503 -> `ServerOverloaded`, 500 -> `ApiInternalServerError`.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for 503 status, got: {error:?}"),
        }

        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!("Expected ApiInternalServerError for 500 status, got: {error:?}"),
        }
    }

    /// A code of the form `upstream_http_<status>` with a plain-text body is
    /// classified from the status embedded in the code itself.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = anthropic_cloud_failure("upstream_http_503", "Service unavailable");

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            // Include the actual error, like the sibling tests, so a failure is diagnosable.
            _ => panic!("Expected ServerOverloaded error for upstream_http_503, got: {error:?}"),
        }
    }

    /// The same connection-timeout message is classified by its upstream status
    /// alone, and for a 500 the upstream message is surfaced verbatim.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {error:?}"
            ),
        }

        let error = anthropic_cloud_failure(
            "upstream_http_error",
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {error:?}"
            ),
        }
    }

    /// A present `thought_signature` is emitted in the serialized JSON.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    /// JSON written without a `thought_signature` key (e.g. by an older
    /// version) still deserializes, with the signature defaulting to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.raw_input, "{\"arg\":\"value\"}");
        assert_eq!(tool_use.input, json!({"arg": "value"}));
        assert!(tool_use.is_input_complete);
        assert_eq!(tool_use.thought_signature, None);
    }

    /// Serializing then deserializing preserves every field, including a
    /// present `thought_signature`.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.raw_input, original.raw_input);
        assert_eq!(deserialized.input, original.input);
        assert_eq!(deserialized.is_input_complete, original.is_input_complete);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    /// Serializing then deserializing preserves every field when
    /// `thought_signature` is `None`.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.raw_input, original.raw_input);
        assert_eq!(deserialized.input, original.input);
        assert_eq!(deserialized.is_input_complete, original.is_input_complete);
        assert_eq!(deserialized.thought_signature, None);
    }
}