1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7pub mod tool_schema;
8
9#[cfg(any(test, feature = "test-support"))]
10pub mod fake_provider;
11
12use anthropic::{AnthropicError, parse_prompt_too_long};
13use anyhow::{Result, anyhow};
14use client::Client;
15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
16use futures::FutureExt;
17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
19use http_client::{StatusCode, http};
20use icons::IconName;
21use open_router::OpenRouterError;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
40
// Well-known providers. Each has a stable machine identifier
// (`LanguageModelProviderId`) and a human-readable display name
// (`LanguageModelProviderName`).

/// Stable identifier for the Anthropic provider.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
/// User-facing name for the Anthropic provider.
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Stable identifier for the Google AI provider.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// User-facing name for the Google AI provider.
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Stable identifier for the OpenAI provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// User-facing name for the OpenAI provider.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Stable identifier for the xAI provider.
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// User-facing name for the xAI provider.
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

/// Stable identifier for Zed's hosted (cloud) provider.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// User-facing name for Zed's hosted (cloud) provider.
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
60
/// Initializes the language model system: registers settings and the
/// listener that refreshes LLM tokens for the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
65
/// Initializes only the settings/registry portion (no client required).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Status change for the underlying completion request.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's reasoning, with an optional signature
    /// (some models require the signature to be preserved — see
    /// `LanguageModelToolUse::thought_signature` for the analogous field).
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque, provider-encoded reasoning data.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Marks the beginning of a new message in the stream.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Updated token usage counts for the request.
    UsageUpdate(TokenUsage),
}
96
/// An error produced while requesting or streaming a completion.
///
/// Variants are grouped roughly by origin: request-size problems, server-side
/// failures, client-side request errors, and transport/(de)serialization
/// failures. See `from_http_status` for the HTTP-status-to-variant mapping.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is configured for the provider.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// HTTP 429 from the provider.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 503 (or Anthropic's non-standard 529) from the provider.
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 500 from the provider.
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An upstream provider error carrying its status and optional retry
    /// delay. NOTE(review): not constructed in this file — confirm callers.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Any other non-success HTTP status.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// HTTP 400: the request was malformed.
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 401: credentials missing or invalid.
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 403: credentials valid but not permitted.
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 404 for the provider's endpoint.
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// I/O failure while reading the response body.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// Failed to serialize the outgoing request as JSON.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// Failed to build the HTTP request body.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Failed to send the HTTP request.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// Failed to deserialize the provider's JSON response.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
184
185impl LanguageModelCompletionError {
186 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
187 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
188 let upstream_status = error_json
189 .get("upstream_status")
190 .and_then(|v| v.as_u64())
191 .and_then(|status| u16::try_from(status).ok())
192 .and_then(|status| StatusCode::from_u16(status).ok())?;
193 let inner_message = error_json
194 .get("message")
195 .and_then(|v| v.as_str())
196 .unwrap_or(message)
197 .to_string();
198 Some((upstream_status, inner_message))
199 }
200
201 pub fn from_cloud_failure(
202 upstream_provider: LanguageModelProviderName,
203 code: String,
204 message: String,
205 retry_after: Option<Duration>,
206 ) -> Self {
207 if let Some(tokens) = parse_prompt_too_long(&message) {
208 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
209 // to be reported. This is a temporary workaround to handle this in the case where the
210 // token limit has been exceeded.
211 Self::PromptTooLarge {
212 tokens: Some(tokens),
213 }
214 } else if code == "upstream_http_error" {
215 if let Some((upstream_status, inner_message)) =
216 Self::parse_upstream_error_json(&message)
217 {
218 return Self::from_http_status(
219 upstream_provider,
220 upstream_status,
221 inner_message,
222 retry_after,
223 );
224 }
225 anyhow!("completion request failed, code: {code}, message: {message}").into()
226 } else if let Some(status_code) = code
227 .strip_prefix("upstream_http_")
228 .and_then(|code| StatusCode::from_str(code).ok())
229 {
230 Self::from_http_status(upstream_provider, status_code, message, retry_after)
231 } else if let Some(status_code) = code
232 .strip_prefix("http_")
233 .and_then(|code| StatusCode::from_str(code).ok())
234 {
235 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
236 } else {
237 anyhow!("completion request failed, code: {code}, message: {message}").into()
238 }
239 }
240
241 pub fn from_http_status(
242 provider: LanguageModelProviderName,
243 status_code: StatusCode,
244 message: String,
245 retry_after: Option<Duration>,
246 ) -> Self {
247 match status_code {
248 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
249 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
250 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
251 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
252 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
253 tokens: parse_prompt_too_long(&message),
254 },
255 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
256 provider,
257 retry_after,
258 },
259 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
260 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
261 provider,
262 retry_after,
263 },
264 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
265 provider,
266 retry_after,
267 },
268 _ => Self::HttpResponseError {
269 provider,
270 status_code,
271 message,
272 },
273 }
274 }
275}
276
277impl From<AnthropicError> for LanguageModelCompletionError {
278 fn from(error: AnthropicError) -> Self {
279 let provider = ANTHROPIC_PROVIDER_NAME;
280 match error {
281 AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
282 AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
283 AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
284 AnthropicError::DeserializeResponse(error) => {
285 Self::DeserializeResponse { provider, error }
286 }
287 AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
288 AnthropicError::HttpResponseError {
289 status_code,
290 message,
291 } => Self::HttpResponseError {
292 provider,
293 status_code,
294 message,
295 },
296 AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
297 provider,
298 retry_after: Some(retry_after),
299 },
300 AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
301 provider,
302 retry_after,
303 },
304 AnthropicError::ApiError(api_error) => api_error.into(),
305 }
306 }
307}
308
309impl From<anthropic::ApiError> for LanguageModelCompletionError {
310 fn from(error: anthropic::ApiError) -> Self {
311 use anthropic::ApiErrorCode::*;
312 let provider = ANTHROPIC_PROVIDER_NAME;
313 match error.code() {
314 Some(code) => match code {
315 InvalidRequestError => Self::BadRequestFormat {
316 provider,
317 message: error.message,
318 },
319 AuthenticationError => Self::AuthenticationError {
320 provider,
321 message: error.message,
322 },
323 PermissionError => Self::PermissionError {
324 provider,
325 message: error.message,
326 },
327 NotFoundError => Self::ApiEndpointNotFound { provider },
328 RequestTooLarge => Self::PromptTooLarge {
329 tokens: parse_prompt_too_long(&error.message),
330 },
331 RateLimitError => Self::RateLimitExceeded {
332 provider,
333 retry_after: None,
334 },
335 ApiError => Self::ApiInternalServerError {
336 provider,
337 message: error.message,
338 },
339 OverloadedError => Self::ServerOverloaded {
340 provider,
341 retry_after: None,
342 },
343 },
344 None => Self::Other(error.into()),
345 }
346 }
347}
348
349impl From<OpenRouterError> for LanguageModelCompletionError {
350 fn from(error: OpenRouterError) -> Self {
351 let provider = LanguageModelProviderName::new("OpenRouter");
352 match error {
353 OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
354 OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
355 OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
356 OpenRouterError::DeserializeResponse(error) => {
357 Self::DeserializeResponse { provider, error }
358 }
359 OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
360 OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
361 provider,
362 retry_after: Some(retry_after),
363 },
364 OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
365 provider,
366 retry_after,
367 },
368 OpenRouterError::ApiError(api_error) => api_error.into(),
369 }
370 }
371}
372
373impl From<open_router::ApiError> for LanguageModelCompletionError {
374 fn from(error: open_router::ApiError) -> Self {
375 use open_router::ApiErrorCode::*;
376 let provider = LanguageModelProviderName::new("OpenRouter");
377 match error.code {
378 InvalidRequestError => Self::BadRequestFormat {
379 provider,
380 message: error.message,
381 },
382 AuthenticationError => Self::AuthenticationError {
383 provider,
384 message: error.message,
385 },
386 PaymentRequiredError => Self::AuthenticationError {
387 provider,
388 message: format!("Payment required: {}", error.message),
389 },
390 PermissionError => Self::PermissionError {
391 provider,
392 message: error.message,
393 },
394 RequestTimedOut => Self::HttpResponseError {
395 provider,
396 status_code: StatusCode::REQUEST_TIMEOUT,
397 message: error.message,
398 },
399 RateLimitError => Self::RateLimitExceeded {
400 provider,
401 retry_after: None,
402 },
403 ApiError => Self::ApiInternalServerError {
404 provider,
405 message: error.message,
406 },
407 OverloadedError => Self::ServerOverloaded {
408 provider,
409 retry_after: None,
410 },
411 }
412 }
413}
414
/// The reason a model stopped generating output.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a completion.
    Refusal,
}
423
/// Token counts for a request, including prompt-cache activity.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the (non-cached) input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens produced in the model's output.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
435
436impl TokenUsage {
437 pub fn total_tokens(&self) -> u64 {
438 self.input_tokens
439 + self.output_tokens
440 + self.cache_read_input_tokens
441 + self.cache_creation_input_tokens
442 }
443}
444
445impl Add<TokenUsage> for TokenUsage {
446 type Output = Self;
447
448 fn add(self, other: Self) -> Self {
449 Self {
450 input_tokens: self.input_tokens + other.input_tokens,
451 output_tokens: self.output_tokens + other.output_tokens,
452 cache_creation_input_tokens: self.cache_creation_input_tokens
453 + other.cache_creation_input_tokens,
454 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
455 }
456 }
457}
458
459impl Sub<TokenUsage> for TokenUsage {
460 type Output = Self;
461
462 fn sub(self, other: Self) -> Self {
463 Self {
464 input_tokens: self.input_tokens - other.input_tokens,
465 output_tokens: self.output_tokens - other.output_tokens,
466 cache_creation_input_tokens: self.cache_creation_input_tokens
467 - other.cache_creation_input_tokens,
468 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
469 }
470 }
471}
472
/// Identifier for a single tool invocation requested by a model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
475
476impl fmt::Display for LanguageModelToolUseId {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 write!(f, "{}", self.0)
479 }
480}
481
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    /// Builds an id from anything convertible to `Arc<str>` (e.g. `&str`, `String`).
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
490
/// A tool invocation requested by a model during a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// Raw (possibly still-streaming) JSON input as received from the model.
    pub raw_input: String,
    /// The parsed JSON input.
    pub input: serde_json::Value,
    /// Whether `input` represents the complete tool input.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
502
/// A text-only view of a completion stream
/// (see `LanguageModel::stream_completion_text`).
pub struct LanguageModelTextStream {
    /// Message id, when the stream began with a `StartMessage` event.
    pub message_id: Option<String>,
    /// The streamed text chunks.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
509
510impl Default for LanguageModelTextStream {
511 fn default() -> Self {
512 Self {
513 message_id: None,
514 stream: Box::pin(futures::stream::empty()),
515 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
516 }
517 }
518}
519
520pub trait LanguageModel: Send + Sync {
521 fn id(&self) -> LanguageModelId;
522 fn name(&self) -> LanguageModelName;
523 fn provider_id(&self) -> LanguageModelProviderId;
524 fn provider_name(&self) -> LanguageModelProviderName;
525 fn upstream_provider_id(&self) -> LanguageModelProviderId {
526 self.provider_id()
527 }
528 fn upstream_provider_name(&self) -> LanguageModelProviderName {
529 self.provider_name()
530 }
531
532 fn telemetry_id(&self) -> String;
533
534 fn api_key(&self, _cx: &App) -> Option<String> {
535 None
536 }
537
538 /// Whether this model supports images
539 fn supports_images(&self) -> bool;
540
541 /// Whether this model supports tools.
542 fn supports_tools(&self) -> bool;
543
544 /// Whether this model supports choosing which tool to use.
545 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
546
547 /// Returns whether this model supports "burn mode";
548 fn supports_burn_mode(&self) -> bool {
549 false
550 }
551
552 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
553 LanguageModelToolSchemaFormat::JsonSchema
554 }
555
556 fn max_token_count(&self) -> u64;
557 /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
558 fn max_token_count_in_burn_mode(&self) -> Option<u64> {
559 None
560 }
561 fn max_output_tokens(&self) -> Option<u64> {
562 None
563 }
564
565 fn count_tokens(
566 &self,
567 request: LanguageModelRequest,
568 cx: &App,
569 ) -> BoxFuture<'static, Result<u64>>;
570
571 fn stream_completion(
572 &self,
573 request: LanguageModelRequest,
574 cx: &AsyncApp,
575 ) -> BoxFuture<
576 'static,
577 Result<
578 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
579 LanguageModelCompletionError,
580 >,
581 >;
582
583 fn stream_completion_text(
584 &self,
585 request: LanguageModelRequest,
586 cx: &AsyncApp,
587 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
588 let future = self.stream_completion(request, cx);
589
590 async move {
591 let events = future.await?;
592 let mut events = events.fuse();
593 let mut message_id = None;
594 let mut first_item_text = None;
595 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
596
597 if let Some(first_event) = events.next().await {
598 match first_event {
599 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
600 message_id = Some(id);
601 }
602 Ok(LanguageModelCompletionEvent::Text(text)) => {
603 first_item_text = Some(text);
604 }
605 _ => (),
606 }
607 }
608
609 let stream = futures::stream::iter(first_item_text.map(Ok))
610 .chain(events.filter_map({
611 let last_token_usage = last_token_usage.clone();
612 move |result| {
613 let last_token_usage = last_token_usage.clone();
614 async move {
615 match result {
616 Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
617 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
618 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
619 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
620 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
621 Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
622 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
623 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
624 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
625 ..
626 }) => None,
627 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
628 *last_token_usage.lock() = token_usage;
629 None
630 }
631 Err(err) => Some(Err(err)),
632 }
633 }
634 }
635 }))
636 .boxed();
637
638 Ok(LanguageModelTextStream {
639 message_id,
640 stream,
641 last_token_usage,
642 })
643 }
644 .boxed()
645 }
646
647 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
648 None
649 }
650
651 #[cfg(any(test, feature = "test-support"))]
652 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
653 unimplemented!()
654 }
655}
656
657pub trait LanguageModelExt: LanguageModel {
658 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
659 match mode {
660 CompletionMode::Normal => self.max_token_count(),
661 CompletionMode::Max => self
662 .max_token_count_in_burn_mode()
663 .unwrap_or_else(|| self.max_token_count()),
664 }
665 }
666}
667impl LanguageModelExt for dyn LanguageModel {}
668
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint refused the connection.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
679
/// A source of language models (e.g. a specific API vendor or Zed's cloud).
pub trait LanguageModelProvider: 'static {
    /// Stable machine identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default fast model, if one is available.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to surface; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether valid credentials are currently available.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration UI for the given target agent.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
702
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other agent, identified by name.
    Other(SharedString),
}
709
/// Where a provider's terms-of-service notice is being displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// In the text-thread popup.
    TextThreadPopup,
    /// In the provider configuration view.
    Configuration,
}
719
720pub trait LanguageModelProviderState: 'static {
721 type ObservableEntity;
722
723 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
724
725 fn subscribe<T: 'static>(
726 &self,
727 cx: &mut gpui::Context<T>,
728 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
729 ) -> Option<gpui::Subscription> {
730 let entity = self.observable_entity()?;
731 Some(cx.observe(&entity, move |this, _, cx| {
732 callback(this, cx);
733 }))
734 }
735}
736
/// Stable identifier for a specific model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable machine identifier for a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
748
impl LanguageModelProviderId {
    /// Creates a provider id from a static string (const-friendly, no allocation).
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
754
755impl LanguageModelProviderName {
756 pub const fn new(id: &'static str) -> Self {
757 Self(SharedString::new_static(id))
758 }
759}
760
761impl fmt::Display for LanguageModelProviderId {
762 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
763 write!(f, "{}", self.0)
764 }
765}
766
767impl fmt::Display for LanguageModelProviderName {
768 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
769 write!(f, "{}", self.0)
770 }
771}
772
773impl From<String> for LanguageModelId {
774 fn from(value: String) -> Self {
775 Self(SharedString::from(value))
776 }
777}
778
779impl From<String> for LanguageModelName {
780 fn from(value: String) -> Self {
781 Self(SharedString::from(value))
782 }
783}
784
785impl From<String> for LanguageModelProviderId {
786 fn from(value: String) -> Self {
787 Self(SharedString::from(value))
788 }
789}
790
791impl From<String> for LanguageModelProviderName {
792 fn from(value: String) -> Self {
793 Self(SharedString::from(value))
794 }
795}
796
797impl From<Arc<str>> for LanguageModelProviderId {
798 fn from(value: Arc<str>) -> Self {
799 Self(SharedString::from(value))
800 }
801}
802
803impl From<Arc<str>> for LanguageModelProviderName {
804 fn from(value: Arc<str>) -> Self {
805 Self(SharedString::from(value))
806 }
807}
808
#[cfg(test)]
mod tests {
    use super::*;

    // `upstream_http_error` payloads embed the real status as JSON; verify
    // they are unwrapped into status-specific variants for the upstream
    // provider (503 -> ServerOverloaded, 500 -> ApiInternalServerError).
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            // The inner JSON "message" field should be extracted, not the raw payload.
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Codes of the form `upstream_http_<status>` carry the status in the code
    // itself rather than in a JSON payload.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // Connection-timeout messages must still map by the embedded upstream
    // status (503 vs 500), preserving the inner message text.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // `thought_signature` must survive serialization so it can be echoed back
    // to providers that validate it.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Older payloads without the field should still deserialize (defaults to None).
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Round trip with a signature present.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Round trip with the signature absent.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}