mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
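///
/// A stream of these events is produced by [`LanguageModel::stream_completion`].
/// A minimal sketch of draining such a stream (illustrative, not from this
/// crate's docs; `stream` is assumed to be the `BoxStream` that method returns):
///
/// ```ignore
/// while let Some(event) = stream.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
///         LanguageModelCompletionEvent::Stop(reason) => println!("stopped: {reason:?}"),
///         _ => {}
///     }
/// }
/// ```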
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
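    /// Extracts the upstream status and message from the JSON body that Zed
    /// Cloud forwards for `upstream_http_error` failures. The payload looks
    /// like this (example taken from the tests at the bottom of this file):
    ///
    /// ```json
    /// {"code": "upstream_http_error", "message": "Internal server error", "upstream_status": 500}
    /// ```
    ///
    /// Returns `None` if the message is not JSON or lacks a valid `upstream_status`.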
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

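    /// Maps a failure reported by Zed Cloud onto a typed error. A message
    /// matching Anthropic's prompt-too-long pattern is short-circuited to
    /// `PromptTooLarge` first (see the TODO below). Otherwise three `code`
    /// shapes are recognized: `upstream_http_error` (status carried in the
    /// JSON `message`), `upstream_http_<status>` (attributed to the upstream
    /// provider), and `http_<status>` (attributed to Zed Cloud itself).
    /// Anything else falls back to [`Self::Other`].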
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently an Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle the case where the token
            // limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

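    /// Maps an HTTP status from a provider's API onto a typed error. Note that
    /// the non-standard status 529 (used by Anthropic to signal overload) is
    /// treated the same as 503 Service Unavailable.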
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
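    /// Total tokens across input, output, and cache creation/read input.
    /// Illustrative example (hypothetical numbers):
    ///
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 1_000,
    ///     output_tokens: 250,
    ///     cache_creation_input_tokens: 100,
    ///     cache_read_input_tokens: 50,
    /// };
    /// assert_eq!(usage.total_tokens(), 1_400);
    /// ```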
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

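// Note: subtraction is component-wise on `u64` fields and will panic in debug
// builds (or wrap in release) if any field of `other` exceeds the matching
// field of `self`; callers presumably subtract an earlier usage snapshot from
// a later one.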
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Contains the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (if
    /// `supports_burn_mode` is `false`, this returns `None`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

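            // Peek at the first event so a leading `StartMessage` can populate
            // `message_id`; if the first event is already text, stash it and
            // replay it at the front of the outgoing stream below.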
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
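    /// Effective context window for the given completion mode. A sketch of the
    /// intended use (`model` is assumed to implement `LanguageModel`):
    ///
    /// ```ignore
    /// // Burn mode ("Max") falls back to the normal limit when the model
    /// // doesn't support it.
    /// let budget = model.max_token_count_for_mode(CompletionMode::Max);
    /// ```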
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

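    /// Re-runs `callback` whenever the provider's observable entity notifies,
    /// returning the guard `Subscription`, or `None` if the provider has no
    /// observable entity.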
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
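
    // Sketch test added for illustration (not part of the original suite):
    // exercises the component-wise `Add`/`Sub` impls on `TokenUsage` above.
    #[test]
    fn test_token_usage_add_sub_round_trip() {
        let a = TokenUsage {
            input_tokens: 10,
            output_tokens: 20,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 4,
        };
        let b = TokenUsage {
            input_tokens: 1,
            output_tokens: 2,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 1,
        };

        let sum = a + b;
        assert_eq!(sum.total_tokens(), a.total_tokens() + b.total_tokens());
        // Subtracting what was added restores the original snapshot.
        assert_eq!(sum - b, a);
    }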
}