mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
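///
/// Events arrive on the stream returned by
/// [`LanguageModel::stream_completion`]. A minimal sketch of draining such a
/// stream (assumes an `events` stream and `StreamExt` are in scope; marked
/// `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// while let Some(event) = events.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
///         LanguageModelCompletionEvent::UsageUpdate(usage) => {
///             eprintln!("tokens so far: {}", usage.total_tokens());
///         }
///         LanguageModelCompletionEvent::Stop(reason) => eprintln!("stopped: {reason:?}"),
///         _ => {}
///     }
/// }
/// ```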
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

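    /// Maps a failure reported by Zed's cloud proxy onto a typed completion error.
    ///
    /// Recognized `code` values are `upstream_http_error` (whose `message` is a
    /// JSON object carrying `upstream_status` and `message`),
    /// `upstream_http_<status>`, and `http_<status>`; anything else falls back
    /// to [`Self::Other`]. A minimal sketch with an invented payload (marked
    /// `ignore` so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let error = LanguageModelCompletionError::from_cloud_failure(
    ///     ANTHROPIC_PROVIDER_NAME,
    ///     "upstream_http_429".to_string(),
    ///     "rate limited".to_string(),
    ///     None,
    /// );
    /// assert!(matches!(
    ///     error,
    ///     LanguageModelCompletionError::RateLimitExceeded { .. }
    /// ));
    /// ```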
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: Anthropic PAYLOAD_TOO_LARGE responses may currently be reported as
            // INTERNAL_SERVER_ERROR. This is a temporary workaround so that the case where
            // the token limit has been exceeded is still surfaced as `PromptTooLarge`.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

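    /// Maps an HTTP status from a provider's API onto a typed completion error.
    ///
    /// `400`, `401`, `403`, `404`, `413`, `429`, `500`, and `503` each map to a
    /// dedicated variant, and the non-standard `529` status (used by Anthropic
    /// to signal overload) is treated like `503`; any other status becomes
    /// [`Self::HttpResponseError`].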
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

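/// Token counts reported by a provider for a completion request.
///
/// The `Add` and `Sub` impls below combine counts field by field. A minimal
/// sketch of the arithmetic (values invented for illustration; marked
/// `ignore` so it is not compiled as a doctest):
///
/// ```ignore
/// let a = TokenUsage { input_tokens: 100, output_tokens: 10, ..Default::default() };
/// let b = TokenUsage { input_tokens: 40, output_tokens: 5, ..Default::default() };
/// assert_eq!((a + b).total_tokens(), 155);
/// ```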
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

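/// A plain-text view of a completion stream, as produced by
/// [`LanguageModel::stream_completion_text`]: only `Text` events are
/// surfaced as stream items, and token usage is recorded off to the side.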
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

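/// The interface implemented by every language model exposed through this crate.
///
/// A minimal sketch of streaming plain text from a model (assumes a
/// `model: Arc<dyn LanguageModel>`, a `request`, and an `AsyncApp` handle
/// named `cx` are in scope; marked `ignore` so it is not compiled as a
/// doctest):
///
/// ```ignore
/// let text_stream = model.stream_completion_text(request, &cx).await?;
/// let mut chunks = text_stream.stream;
/// while let Some(chunk) = chunks.next().await {
///     print!("{}", chunk?);
/// }
/// let usage = *text_stream.last_token_usage.lock();
/// ```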
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode" ([`CompletionMode::Max`]).
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (if
    /// `supports_burn_mode` returns `false`, this returns `None`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message ID rather than text.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

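/// Extension methods layered on top of [`LanguageModel`].
///
/// `max_token_count_for_mode` resolves the effective context window for a
/// [`CompletionMode`]: `Normal` uses `max_token_count`, while `Max` (burn
/// mode) uses `max_token_count_in_burn_mode`, falling back to the normal
/// count for models that don't support burn mode.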
pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
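
    // These tests exercise behavior defined earlier in this file: the
    // field-by-field `Add` impl on `TokenUsage`, and `from_http_status`'s
    // dedicated handling of the non-standard 529 status.
    #[test]
    fn test_token_usage_add_and_total() {
        let a = TokenUsage {
            input_tokens: 100,
            output_tokens: 10,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 7,
        };
        let b = TokenUsage {
            input_tokens: 40,
            output_tokens: 5,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 2,
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, 140);
        assert_eq!(sum.output_tokens, 15);
        assert_eq!(sum.total_tokens(), 168);
    }

    #[test]
    fn test_from_http_status_529_is_server_overloaded() {
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            StatusCode::from_u16(529).unwrap(),
            "overloaded".to_string(),
            None,
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::ServerOverloaded { .. }
        ));
    }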
}