1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7pub mod tool_schema;
8
9#[cfg(any(test, feature = "test-support"))]
10pub mod fake_provider;
11
12use anthropic::{AnthropicError, parse_prompt_too_long};
13use anyhow::{Result, anyhow};
14use client::Client;
15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
16use futures::FutureExt;
17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
19use http_client::{StatusCode, http};
20use icons::IconName;
21use open_router::OpenRouterError;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
40
// Well-known provider identities. The `*_PROVIDER_ID` constants are the stable,
// machine-readable keys (used in settings/telemetry lookups); the
// `*_PROVIDER_NAME` constants are the human-readable labels shown in the UI.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

// Zed's own cloud provider, which proxies to upstream providers (see
// `LanguageModelCompletionError::from_cloud_failure`).
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
60
/// Initializes the language model subsystem: registers settings and starts
/// listening for LLM token refresh events on the given client connection.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

/// Registers only the settings portion (used by `init`, and directly in
/// contexts that have no client connection).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Intermediate status reported by the backend (e.g. queued/started).
    StatusUpdate(CompletionRequestStatus),
    /// The stream finished, with the reason the model stopped.
    Stop(StopReason),
    /// A chunk of plain response text.
    Text(String),
    /// A chunk of the model's reasoning. `signature` is an opaque
    /// model-provided validation token, when present.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning content the provider returned only in encrypted/redacted form.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON; the raw
    /// input is preserved for diagnostics.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// The start of a new assistant message, carrying its provider-assigned id.
    StartMessage {
        message_id: String,
    },
    /// Updated token accounting; the last such event holds the final totals.
    UsageUpdate(TokenUsage),
}
95
/// An error that occurred while requesting or streaming a completion.
///
/// Variants are grouped roughly by origin: server-side conditions first, then
/// client/request errors, then transport and (de)serialization failures. The
/// `#[error]` strings are user-visible; do not change them casually.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window. `tokens` is the parsed
    /// token count when the provider's message included one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded verbatim from an upstream provider via the cloud
    /// proxy; `status` is the upstream HTTP status.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not mapped to a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
183
184impl LanguageModelCompletionError {
185 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
186 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
187 let upstream_status = error_json
188 .get("upstream_status")
189 .and_then(|v| v.as_u64())
190 .and_then(|status| u16::try_from(status).ok())
191 .and_then(|status| StatusCode::from_u16(status).ok())?;
192 let inner_message = error_json
193 .get("message")
194 .and_then(|v| v.as_str())
195 .unwrap_or(message)
196 .to_string();
197 Some((upstream_status, inner_message))
198 }
199
200 pub fn from_cloud_failure(
201 upstream_provider: LanguageModelProviderName,
202 code: String,
203 message: String,
204 retry_after: Option<Duration>,
205 ) -> Self {
206 if let Some(tokens) = parse_prompt_too_long(&message) {
207 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
208 // to be reported. This is a temporary workaround to handle this in the case where the
209 // token limit has been exceeded.
210 Self::PromptTooLarge {
211 tokens: Some(tokens),
212 }
213 } else if code == "upstream_http_error" {
214 if let Some((upstream_status, inner_message)) =
215 Self::parse_upstream_error_json(&message)
216 {
217 return Self::from_http_status(
218 upstream_provider,
219 upstream_status,
220 inner_message,
221 retry_after,
222 );
223 }
224 anyhow!("completion request failed, code: {code}, message: {message}").into()
225 } else if let Some(status_code) = code
226 .strip_prefix("upstream_http_")
227 .and_then(|code| StatusCode::from_str(code).ok())
228 {
229 Self::from_http_status(upstream_provider, status_code, message, retry_after)
230 } else if let Some(status_code) = code
231 .strip_prefix("http_")
232 .and_then(|code| StatusCode::from_str(code).ok())
233 {
234 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
235 } else {
236 anyhow!("completion request failed, code: {code}, message: {message}").into()
237 }
238 }
239
240 pub fn from_http_status(
241 provider: LanguageModelProviderName,
242 status_code: StatusCode,
243 message: String,
244 retry_after: Option<Duration>,
245 ) -> Self {
246 match status_code {
247 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
248 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
249 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
250 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
251 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
252 tokens: parse_prompt_too_long(&message),
253 },
254 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
255 provider,
256 retry_after,
257 },
258 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
259 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
260 provider,
261 retry_after,
262 },
263 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
264 provider,
265 retry_after,
266 },
267 _ => Self::HttpResponseError {
268 provider,
269 status_code,
270 message,
271 },
272 }
273 }
274}
275
/// Maps Anthropic transport/protocol errors onto the provider-agnostic error,
/// tagging each variant with the Anthropic provider name.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own error-code mapping below.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

/// Maps Anthropic API error codes (from the response body) onto structured
/// variants; codes we don't recognize fall through to `Other`.
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                // Anthropic does not include retry-after in API-level errors.
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

/// Maps OpenRouter transport/protocol errors onto the provider-agnostic error.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

/// Maps OpenRouter API error codes onto structured variants.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // Treated as an authentication problem since the fix (add credit /
            // update billing) lives with the account credentials.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
413
/// Why a model stopped generating (serialized in snake_case to match
/// provider wire formats).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
422
/// Token accounting for a completion request. All counts default to zero and
/// are omitted from serialization when zero (`is_default`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Tokens in the prompt, excluding cache reads/writes.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
434
435impl TokenUsage {
436 pub fn total_tokens(&self) -> u64 {
437 self.input_tokens
438 + self.output_tokens
439 + self.cache_read_input_tokens
440 + self.cache_creation_input_tokens
441 }
442}
443
444impl Add<TokenUsage> for TokenUsage {
445 type Output = Self;
446
447 fn add(self, other: Self) -> Self {
448 Self {
449 input_tokens: self.input_tokens + other.input_tokens,
450 output_tokens: self.output_tokens + other.output_tokens,
451 cache_creation_input_tokens: self.cache_creation_input_tokens
452 + other.cache_creation_input_tokens,
453 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
454 }
455 }
456}
457
458impl Sub<TokenUsage> for TokenUsage {
459 type Output = Self;
460
461 fn sub(self, other: Self) -> Self {
462 Self {
463 input_tokens: self.input_tokens - other.input_tokens,
464 output_tokens: self.output_tokens - other.output_tokens,
465 cache_creation_input_tokens: self.cache_creation_input_tokens
466 - other.cache_creation_input_tokens,
467 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
468 }
469 }
470}
471
/// Opaque identifier a model assigns to a single tool invocation; used to
/// correlate a tool result with the request that produced it.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
474
475impl fmt::Display for LanguageModelToolUseId {
476 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
477 write!(f, "{}", self.0)
478 }
479}
480
/// Blanket conversion from anything that converts to `Arc<str>` (e.g. `&str`,
/// `String`), so ids can be constructed ergonomically at call sites.
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
489
/// A tool invocation requested by the model during a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    // Name of the tool the model wants to call.
    pub name: Arc<str>,
    // The tool input exactly as streamed from the model, before JSON parsing.
    pub raw_input: String,
    // Parsed form of `raw_input`.
    pub input: serde_json::Value,
    // False while input is still streaming in; true once it is final.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
501
/// A text-only view of a completion stream (see
/// `LanguageModel::stream_completion_text`).
pub struct LanguageModelTextStream {
    // Provider-assigned message id, when the stream began with a
    // `StartMessage` event.
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
508
509impl Default for LanguageModelTextStream {
510 fn default() -> Self {
511 Self {
512 message_id: None,
513 stream: Box::pin(futures::stream::empty()),
514 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
515 }
516 }
517}
518
/// A handle to a single language model from some provider.
///
/// Implementations must be `Send + Sync` so models can be shared freely.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// The provider that actually serves requests. Defaults to `provider_id`;
    /// intended to differ when requests are proxied through another provider.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    /// The API key in use for this model, if any (providers may manage
    /// credentials themselves and return `None`).
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode";
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for `request`. The outer error covers request
    /// setup; per-event errors are delivered inside the stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > ;

    /// Adapts `stream_completion` into a plain text stream: keeps only `Text`
    /// events, captures the message id from a leading `StartMessage`, and
    /// records the latest `UsageUpdate` into `last_token_usage`. All other
    /// event kinds (thinking, tool use, stop, status) are dropped.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message id, or may
            // already be text that must be re-emitted at the head of the stream.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    // Test-only downcast to the fake model implementation.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
654
655pub trait LanguageModelExt: LanguageModel {
656 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
657 match mode {
658 CompletionMode::Normal => self.max_token_count(),
659 CompletionMode::Max => self
660 .max_token_count_in_burn_mode()
661 .unwrap_or_else(|| self.max_token_count()),
662 }
663 }
664}
665impl LanguageModelExt for dyn LanguageModel {}
666
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials exist for this provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
677
/// A source of language models (e.g. a vendor API or local runtime),
/// responsible for authentication, model listing, and its configuration UI.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A cheaper/faster default, used where latency matters more than quality.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's settings/configuration view for the given agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
700
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// An external agent, identified by its display name.
    Other(SharedString),
}

/// Where a provider's terms-of-service notice is being displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
717
/// Providers whose state can be observed through a gpui entity, so UI can
/// react when the provider's models or authentication change.
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    /// The entity to observe, or `None` if this provider has no observable state.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Invokes `callback` whenever the observable entity notifies. Returns
    /// `None` when there is nothing to observe; dropping the returned
    /// subscription cancels the observation.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
734
/// Stable identifier for a model (machine-readable).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model display name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable identifier for a provider (machine-readable).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider display name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    /// Const constructor from a static string, so ids can live in `const`s.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Const constructor from a static string, so names can live in `const`s.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
758
759impl fmt::Display for LanguageModelProviderId {
760 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
761 write!(f, "{}", self.0)
762 }
763}
764
765impl fmt::Display for LanguageModelProviderName {
766 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
767 write!(f, "{}", self.0)
768 }
769}
770
771impl From<String> for LanguageModelId {
772 fn from(value: String) -> Self {
773 Self(SharedString::from(value))
774 }
775}
776
777impl From<String> for LanguageModelName {
778 fn from(value: String) -> Self {
779 Self(SharedString::from(value))
780 }
781}
782
783impl From<String> for LanguageModelProviderId {
784 fn from(value: String) -> Self {
785 Self(SharedString::from(value))
786 }
787}
788
789impl From<String> for LanguageModelProviderName {
790 fn from(value: String) -> Self {
791 Self(SharedString::from(value))
792 }
793}
794
795impl From<Arc<str>> for LanguageModelProviderId {
796 fn from(value: Arc<str>) -> Self {
797 Self(SharedString::from(value))
798 }
799}
800
801impl From<Arc<str>> for LanguageModelProviderName {
802 fn from(value: Arc<str>) -> Self {
803 Self(SharedString::from(value))
804 }
805}
806
#[cfg(test)]
mod tests {
    //! Covers the cloud-failure → error mapping and `LanguageModelToolUse`
    //! serde round-trips (particularly the optional `thought_signature`).
    use super::*;

    // JSON-payload form: status comes from the `upstream_status` field.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Code-suffix form: status is parsed from the `upstream_http_<status>` code.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // The long upstream message must survive intact while the status still
    // picks the variant (503 → ServerOverloaded, 500 → ApiInternalServerError).
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Older payloads predate `thought_signature`; it must default to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}