1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7
8#[cfg(any(test, feature = "test-support"))]
9pub mod fake_provider;
10
11use anthropic::{AnthropicError, parse_prompt_too_long};
12use anyhow::{Result, anyhow};
13use client::Client;
14use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
15use futures::FutureExt;
16use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
17use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
18use http_client::{StatusCode, http};
19use icons::IconName;
20use open_router::OpenRouterError;
21use parking_lot::Mutex;
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize, de::DeserializeOwned};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39
/// Identifier of the Anthropic language model provider.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
/// Display name of the Anthropic language model provider; errors converted from
/// `AnthropicError` are attributed to this name.
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Identifier of the Google AI language model provider.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// Display name of the Google AI language model provider.
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Identifier of the OpenAI language model provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Display name of the OpenAI language model provider.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Identifier of Zed's own (`zed.dev`) language model provider.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// Display name of Zed's own provider; `http_<status>` cloud failures are attributed to it.
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
56
/// Initializes this crate: runs [`init_settings`], then registers the
/// `RefreshLlmTokenListener` for `client`.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
61
/// Initializes the language model registry (delegates to `registry::init`).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
65
/// A completion event from a language model.
///
/// These are the items of the stream returned by [`LanguageModel::stream_completion`].
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A status update about the request itself (see [`CompletionRequestStatus`]).
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating for the given reason.
    Stop(StopReason),
    /// A chunk of generated text.
    Text(String),
    /// A chunk of model "thinking" output, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content delivered only as an opaque `data` payload (presumably redacted by
    /// the provider — confirm per provider).
    RedactedThinking {
        data: String,
    },
    /// A tool invocation requested by the model; see `LanguageModelToolUse::is_input_complete`
    /// for whether its input is final.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON; the raw text and the
    /// parser's error message are preserved.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Start of a message, carrying its id.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage. `LanguageModel::stream_completion_text` keeps only the most
    /// recent value it sees.
    UsageUpdate(TokenUsage),
}
91
/// Errors produced while requesting or streaming a completion.
///
/// Most variants record the provider the failure is attributed to, so the `Display` output
/// names the service involved.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt did not fit the model's context window. `tokens` is the prompt size parsed
    /// from the provider's message via `parse_prompt_too_long`, when that succeeds.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is configured for `provider`.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// The provider rate-limited the request (HTTP 429 in `from_http_status`).
    /// `retry_after` is the suggested delay, when one was supplied.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported being overloaded (HTTP 503 or 529 in `from_http_status`).
    /// `retry_after` is the suggested delay, when one was supplied.
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported an internal error (HTTP 500 in `from_http_status`).
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An upstream provider error, described by its message and status.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// An HTTP error status without a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client-side errors: the request was rejected, or could not be built, sent, or read.
    /// The provider rejected the request as malformed (HTTP 400 in `from_http_status`).
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The provider rejected the credentials (HTTP 401 in `from_http_status`).
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The credentials lack permission for the request (HTTP 403 in `from_http_status`).
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The endpoint was not found (HTTP 404 in `from_http_status`).
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// Reading the response body failed with an I/O error.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// Serializing the request to JSON failed.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// Building the HTTP request failed.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// Deserializing the response JSON failed.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    /// Any other failure; displayed transparently as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
179
180impl LanguageModelCompletionError {
181 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
182 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
183 let upstream_status = error_json
184 .get("upstream_status")
185 .and_then(|v| v.as_u64())
186 .and_then(|status| u16::try_from(status).ok())
187 .and_then(|status| StatusCode::from_u16(status).ok())?;
188 let inner_message = error_json
189 .get("message")
190 .and_then(|v| v.as_str())
191 .unwrap_or(message)
192 .to_string();
193 Some((upstream_status, inner_message))
194 }
195
196 pub fn from_cloud_failure(
197 upstream_provider: LanguageModelProviderName,
198 code: String,
199 message: String,
200 retry_after: Option<Duration>,
201 ) -> Self {
202 if let Some(tokens) = parse_prompt_too_long(&message) {
203 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
204 // to be reported. This is a temporary workaround to handle this in the case where the
205 // token limit has been exceeded.
206 Self::PromptTooLarge {
207 tokens: Some(tokens),
208 }
209 } else if code == "upstream_http_error" {
210 if let Some((upstream_status, inner_message)) =
211 Self::parse_upstream_error_json(&message)
212 {
213 return Self::from_http_status(
214 upstream_provider,
215 upstream_status,
216 inner_message,
217 retry_after,
218 );
219 }
220 anyhow!("completion request failed, code: {code}, message: {message}").into()
221 } else if let Some(status_code) = code
222 .strip_prefix("upstream_http_")
223 .and_then(|code| StatusCode::from_str(code).ok())
224 {
225 Self::from_http_status(upstream_provider, status_code, message, retry_after)
226 } else if let Some(status_code) = code
227 .strip_prefix("http_")
228 .and_then(|code| StatusCode::from_str(code).ok())
229 {
230 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
231 } else {
232 anyhow!("completion request failed, code: {code}, message: {message}").into()
233 }
234 }
235
236 pub fn from_http_status(
237 provider: LanguageModelProviderName,
238 status_code: StatusCode,
239 message: String,
240 retry_after: Option<Duration>,
241 ) -> Self {
242 match status_code {
243 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
244 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
245 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
246 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
247 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
248 tokens: parse_prompt_too_long(&message),
249 },
250 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
251 provider,
252 retry_after,
253 },
254 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
255 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
256 provider,
257 retry_after,
258 },
259 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
260 provider,
261 retry_after,
262 },
263 _ => Self::HttpResponseError {
264 provider,
265 status_code,
266 message,
267 },
268 }
269 }
270}
271
272impl From<AnthropicError> for LanguageModelCompletionError {
273 fn from(error: AnthropicError) -> Self {
274 let provider = ANTHROPIC_PROVIDER_NAME;
275 match error {
276 AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
277 AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
278 AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
279 AnthropicError::DeserializeResponse(error) => {
280 Self::DeserializeResponse { provider, error }
281 }
282 AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
283 AnthropicError::HttpResponseError {
284 status_code,
285 message,
286 } => Self::HttpResponseError {
287 provider,
288 status_code,
289 message,
290 },
291 AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
292 provider,
293 retry_after: Some(retry_after),
294 },
295 AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
296 provider,
297 retry_after,
298 },
299 AnthropicError::ApiError(api_error) => api_error.into(),
300 }
301 }
302}
303
304impl From<anthropic::ApiError> for LanguageModelCompletionError {
305 fn from(error: anthropic::ApiError) -> Self {
306 use anthropic::ApiErrorCode::*;
307 let provider = ANTHROPIC_PROVIDER_NAME;
308 match error.code() {
309 Some(code) => match code {
310 InvalidRequestError => Self::BadRequestFormat {
311 provider,
312 message: error.message,
313 },
314 AuthenticationError => Self::AuthenticationError {
315 provider,
316 message: error.message,
317 },
318 PermissionError => Self::PermissionError {
319 provider,
320 message: error.message,
321 },
322 NotFoundError => Self::ApiEndpointNotFound { provider },
323 RequestTooLarge => Self::PromptTooLarge {
324 tokens: parse_prompt_too_long(&error.message),
325 },
326 RateLimitError => Self::RateLimitExceeded {
327 provider,
328 retry_after: None,
329 },
330 ApiError => Self::ApiInternalServerError {
331 provider,
332 message: error.message,
333 },
334 OverloadedError => Self::ServerOverloaded {
335 provider,
336 retry_after: None,
337 },
338 },
339 None => Self::Other(error.into()),
340 }
341 }
342}
343
344impl From<OpenRouterError> for LanguageModelCompletionError {
345 fn from(error: OpenRouterError) -> Self {
346 let provider = LanguageModelProviderName::new("OpenRouter");
347 match error {
348 OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
349 OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
350 OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
351 OpenRouterError::DeserializeResponse(error) => {
352 Self::DeserializeResponse { provider, error }
353 }
354 OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
355 OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
356 provider,
357 retry_after: Some(retry_after),
358 },
359 OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
360 provider,
361 retry_after,
362 },
363 OpenRouterError::ApiError(api_error) => api_error.into(),
364 }
365 }
366}
367
368impl From<open_router::ApiError> for LanguageModelCompletionError {
369 fn from(error: open_router::ApiError) -> Self {
370 use open_router::ApiErrorCode::*;
371 let provider = LanguageModelProviderName::new("OpenRouter");
372 match error.code {
373 InvalidRequestError => Self::BadRequestFormat {
374 provider,
375 message: error.message,
376 },
377 AuthenticationError => Self::AuthenticationError {
378 provider,
379 message: error.message,
380 },
381 PaymentRequiredError => Self::AuthenticationError {
382 provider,
383 message: format!("Payment required: {}", error.message),
384 },
385 PermissionError => Self::PermissionError {
386 provider,
387 message: error.message,
388 },
389 RequestTimedOut => Self::HttpResponseError {
390 provider,
391 status_code: StatusCode::REQUEST_TIMEOUT,
392 message: error.message,
393 },
394 RateLimitError => Self::RateLimitExceeded {
395 provider,
396 retry_after: None,
397 },
398 ApiError => Self::ApiInternalServerError {
399 provider,
400 message: error.message,
401 },
402 OverloadedError => Self::ServerOverloaded {
403 provider,
404 retry_after: None,
405 },
406 }
407 }
408}
409
/// Indicates the format used to define the input schema for a language model tool.
///
/// [`LanguageModel::tool_input_format`] returns this; its default is [`Self::JsonSchema`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see
    /// <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}
418
/// Why the model stopped generating; carried by [`LanguageModelCompletionEvent::Stop`].
///
/// Serialized in snake_case: `end_turn`, `max_tokens`, `tool_use`, `refusal`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}
427
/// Token counts reported for a completion.
///
/// Every field defaults when missing during deserialization and is omitted during
/// serialization when `is_default` holds for it (presumably when it is 0).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens in the (uncached) input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens the model generated.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens read from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
439
440impl TokenUsage {
441 pub fn total_tokens(&self) -> u64 {
442 self.input_tokens
443 + self.output_tokens
444 + self.cache_read_input_tokens
445 + self.cache_creation_input_tokens
446 }
447}
448
449impl Add<TokenUsage> for TokenUsage {
450 type Output = Self;
451
452 fn add(self, other: Self) -> Self {
453 Self {
454 input_tokens: self.input_tokens + other.input_tokens,
455 output_tokens: self.output_tokens + other.output_tokens,
456 cache_creation_input_tokens: self.cache_creation_input_tokens
457 + other.cache_creation_input_tokens,
458 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
459 }
460 }
461}
462
463impl Sub<TokenUsage> for TokenUsage {
464 type Output = Self;
465
466 fn sub(self, other: Self) -> Self {
467 Self {
468 input_tokens: self.input_tokens - other.input_tokens,
469 output_tokens: self.output_tokens - other.output_tokens,
470 cache_creation_input_tokens: self.cache_creation_input_tokens
471 - other.cache_creation_input_tokens,
472 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
473 }
474 }
475}
476
/// Identifier of a single tool invocation within a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
479
480impl fmt::Display for LanguageModelToolUseId {
481 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
482 write!(f, "{}", self.0)
483 }
484}
485
486impl<T> From<T> for LanguageModelToolUseId
487where
488 T: Into<Arc<str>>,
489{
490 fn from(value: T) -> Self {
491 Self(value.into())
492 }
493}
494
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this invocation.
    pub id: LanguageModelToolUseId,
    /// Name of the tool to invoke.
    pub name: Arc<str>,
    /// Tool input exactly as emitted by the model.
    pub raw_input: String,
    /// Tool input as a JSON value.
    pub input: serde_json::Value,
    /// Whether the input is final (presumably `false` while input is still streaming —
    /// confirm with providers).
    pub is_input_complete: bool,
}
503
/// Text-only view of a completion, as produced by [`LanguageModel::stream_completion_text`].
pub struct LanguageModelTextStream {
    /// Id of the message, when the completion started with a `StartMessage` event.
    pub message_id: Option<String>,
    /// Text chunks and errors from the completion.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Most recent token usage reported by the model; complete only once `stream` has
    /// finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
510
511impl Default for LanguageModelTextStream {
512 fn default() -> Self {
513 Self {
514 message_id: None,
515 stream: Box::pin(futures::stream::empty()),
516 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
517 }
518 }
519}
520
521pub trait LanguageModel: Send + Sync {
522 fn id(&self) -> LanguageModelId;
523 fn name(&self) -> LanguageModelName;
524 fn provider_id(&self) -> LanguageModelProviderId;
525 fn provider_name(&self) -> LanguageModelProviderName;
526 fn upstream_provider_id(&self) -> LanguageModelProviderId {
527 self.provider_id()
528 }
529 fn upstream_provider_name(&self) -> LanguageModelProviderName {
530 self.provider_name()
531 }
532
533 fn telemetry_id(&self) -> String;
534
535 fn api_key(&self, _cx: &App) -> Option<String> {
536 None
537 }
538
539 /// Whether this model supports images
540 fn supports_images(&self) -> bool;
541
542 /// Whether this model supports tools.
543 fn supports_tools(&self) -> bool;
544
545 /// Whether this model supports choosing which tool to use.
546 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
547
548 /// Returns whether this model supports "burn mode";
549 fn supports_burn_mode(&self) -> bool {
550 false
551 }
552
553 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
554 LanguageModelToolSchemaFormat::JsonSchema
555 }
556
557 fn max_token_count(&self) -> u64;
558 /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
559 fn max_token_count_in_burn_mode(&self) -> Option<u64> {
560 None
561 }
562 fn max_output_tokens(&self) -> Option<u64> {
563 None
564 }
565
566 fn count_tokens(
567 &self,
568 request: LanguageModelRequest,
569 cx: &App,
570 ) -> BoxFuture<'static, Result<u64>>;
571
572 fn stream_completion(
573 &self,
574 request: LanguageModelRequest,
575 cx: &AsyncApp,
576 ) -> BoxFuture<
577 'static,
578 Result<
579 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
580 LanguageModelCompletionError,
581 >,
582 >;
583
584 fn stream_completion_text(
585 &self,
586 request: LanguageModelRequest,
587 cx: &AsyncApp,
588 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
589 let future = self.stream_completion(request, cx);
590
591 async move {
592 let events = future.await?;
593 let mut events = events.fuse();
594 let mut message_id = None;
595 let mut first_item_text = None;
596 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
597
598 if let Some(first_event) = events.next().await {
599 match first_event {
600 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
601 message_id = Some(id);
602 }
603 Ok(LanguageModelCompletionEvent::Text(text)) => {
604 first_item_text = Some(text);
605 }
606 _ => (),
607 }
608 }
609
610 let stream = futures::stream::iter(first_item_text.map(Ok))
611 .chain(events.filter_map({
612 let last_token_usage = last_token_usage.clone();
613 move |result| {
614 let last_token_usage = last_token_usage.clone();
615 async move {
616 match result {
617 Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
618 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
619 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
620 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
621 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
622 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
623 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
624 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
625 ..
626 }) => None,
627 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
628 *last_token_usage.lock() = token_usage;
629 None
630 }
631 Err(err) => Some(Err(err)),
632 }
633 }
634 }
635 }))
636 .boxed();
637
638 Ok(LanguageModelTextStream {
639 message_id,
640 stream,
641 last_token_usage,
642 })
643 }
644 .boxed()
645 }
646
647 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
648 None
649 }
650
651 #[cfg(any(test, feature = "test-support"))]
652 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
653 unimplemented!()
654 }
655}
656
657pub trait LanguageModelExt: LanguageModel {
658 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
659 match mode {
660 CompletionMode::Normal => self.max_token_count(),
661 CompletionMode::Max => self
662 .max_token_count_in_burn_mode()
663 .unwrap_or_else(|| self.max_token_count()),
664 }
665 }
666}
667impl LanguageModelExt for dyn LanguageModel {}
668
/// A tool whose input deserializes into `Self` and whose input schema is derived from
/// `Self` via [`JsonSchema`].
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// Name of the tool.
    fn name() -> String;
    /// Description of the tool.
    fn description() -> String;
}
673
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The connection was refused.
    #[error("connection refused")]
    ConnectionRefused,
    /// No credentials were found.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure; displayed transparently as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
684
/// A source of language models (a service or backend that offers one or more models).
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Display name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon for this provider. Defaults to [`IconName::ZedAssistant`].
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// Model used by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// Model used when a fast model is wanted, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models this provider recommends. Defaults to none.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider currently has valid authentication.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider for `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
707
/// The agent a provider configuration view is built for. Defaults to [`Self::ZedAgent`].
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's own agent.
    #[default]
    ZedAgent,
    /// Some other agent, identified by the given string (presumably its display name —
    /// confirm with callers).
    Other(SharedString),
}
714
/// Where a provider's terms-of-service view is being shown (meanings below are partly
/// inferred from the variant names — confirm with callers).
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Presumably the text-thread popup — TODO confirm.
    TextThreadPopup,
    /// Presumably the provider configuration view — TODO confirm.
    Configuration,
}
724
725pub trait LanguageModelProviderState: 'static {
726 type ObservableEntity;
727
728 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
729
730 fn subscribe<T: 'static>(
731 &self,
732 cx: &mut gpui::Context<T>,
733 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
734 ) -> Option<gpui::Subscription> {
735 let entity = self.observable_entity()?;
736 Some(cx.observe(&entity, move |this, _, cx| {
737 callback(this, cx);
738 }))
739 }
740}
741
/// Identifier of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Display name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Display name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
753
754impl LanguageModelProviderId {
755 pub const fn new(id: &'static str) -> Self {
756 Self(SharedString::new_static(id))
757 }
758}
759
760impl LanguageModelProviderName {
761 pub const fn new(id: &'static str) -> Self {
762 Self(SharedString::new_static(id))
763 }
764}
765
766impl fmt::Display for LanguageModelProviderId {
767 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
768 write!(f, "{}", self.0)
769 }
770}
771
772impl fmt::Display for LanguageModelProviderName {
773 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
774 write!(f, "{}", self.0)
775 }
776}
777
778impl From<String> for LanguageModelId {
779 fn from(value: String) -> Self {
780 Self(SharedString::from(value))
781 }
782}
783
784impl From<String> for LanguageModelName {
785 fn from(value: String) -> Self {
786 Self(SharedString::from(value))
787 }
788}
789
790impl From<String> for LanguageModelProviderId {
791 fn from(value: String) -> Self {
792 Self(SharedString::from(value))
793 }
794}
795
796impl From<String> for LanguageModelProviderName {
797 fn from(value: String) -> Self {
798 Self(SharedString::from(value))
799 }
800}
801
802impl From<Arc<str>> for LanguageModelProviderId {
803 fn from(value: Arc<str>) -> Self {
804 Self(SharedString::from(value))
805 }
806}
807
808impl From<Arc<str>> for LanguageModelProviderName {
809 fn from(value: Arc<str>) -> Self {
810 Self(SharedString::from(value))
811 }
812}
813
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    /// `http_<status>` codes describe failures of Zed's own service, so they must be
    /// attributed to the Zed provider rather than the upstream one, and keep `retry_after`.
    #[test]
    fn test_from_cloud_failure_with_zed_http_status() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "http_429".to_string(),
            "Too many requests".to_string(),
            Some(Duration::from_secs(3)),
        );

        match error {
            LanguageModelCompletionError::RateLimitExceeded {
                provider,
                retry_after,
            } => {
                assert_eq!(provider, ZED_CLOUD_PROVIDER_NAME);
                assert_eq!(retry_after, Some(Duration::from_secs(3)));
            }
            _ => panic!(
                "Expected RateLimitExceeded attributed to Zed for http_429, got: {:?}",
                error
            ),
        }
    }

    /// The non-standard 529 status has no named constant but must still map to
    /// `ServerOverloaded`.
    #[test]
    fn test_from_cloud_failure_with_non_standard_overloaded_status() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_529".to_string(),
            "Overloaded".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for upstream_http_529, got: {:?}",
                error
            ),
        }
    }

    /// Payloads that cannot be parsed, and unknown codes, fall back to `Other` with a
    /// message that preserves both the code and the original message.
    #[test]
    fn test_from_cloud_failure_falls_back_to_other() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            "not json".to_string(),
            None,
        );
        assert!(
            matches!(error, LanguageModelCompletionError::Other(_)),
            "Expected Other for unparseable upstream_http_error, got: {:?}",
            error
        );
        assert_eq!(
            error.to_string(),
            "completion request failed, code: upstream_http_error, message: not json"
        );

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "some_unknown_code".to_string(),
            "oops".to_string(),
            None,
        );
        assert!(
            matches!(error, LanguageModelCompletionError::Other(_)),
            "Expected Other for an unknown code, got: {:?}",
            error
        );
        assert_eq!(
            error.to_string(),
            "completion request failed, code: some_unknown_code, message: oops"
        );
    }
}