mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: Currently an Anthropic PAYLOAD_TOO_LARGE response may be reported as
            // INTERNAL_SERVER_ERROR. This is a temporary workaround to handle the case
            // where the token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
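            // The `upstream_http_error` code carries a JSON payload whose
            // `upstream_status` field holds the provider's original HTTP status;
            // unwrap it so the failure is attributed to the upstream provider.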
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
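            // 529 is a non-standard status code that some providers (notably
            // Anthropic) return when their servers are overloaded, so treat it
            // like SERVICE_UNAVAILABLE.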
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub const fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

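// Note: these counters are unsigned, so `sub` underflows (panicking in debug
// builds) unless `other` is component-wise no larger than `self`; it is suited
// to diffing a later cumulative usage snapshot against an earlier one.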
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Contains the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode
    /// (`None` if `supports_burn_mode` returns `false`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

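            // Peek at the first event so a leading `StartMessage` can supply the
            // message ID (or a first `Text` chunk can be replayed) before building
            // the text-only stream below.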
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(name: &'static str) -> Self {
        Self(SharedString::new_static(name))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }
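
    // A supplementary sketch (not part of the original tests) of the direct
    // status-code mapping in `from_http_status`: 429 should surface as a rate
    // limit that preserves `retry_after`, and the non-standard 529 should read
    // as server overload.
    #[test]
    fn test_from_http_status_mappings() {
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            StatusCode::TOO_MANY_REQUESTS,
            "rate limited".to_string(),
            Some(Duration::from_secs(30)),
        );

        match error {
            LanguageModelCompletionError::RateLimitExceeded { retry_after, .. } => {
                assert_eq!(retry_after, Some(Duration::from_secs(30)));
            }
            _ => panic!(
                "Expected RateLimitExceeded error for 429 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            StatusCode::from_u16(529).unwrap(),
            "overloaded".to_string(),
            None,
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::ServerOverloaded { .. }
        ));
    }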

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }
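
    // A supplementary sketch (not part of the original tests) of the plain
    // `http_` code prefix, which attributes the failure to Zed's own API rather
    // than the upstream provider.
    #[test]
    fn test_from_cloud_failure_with_zed_http_code() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "http_429".to_string(),
            "Too many requests".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::RateLimitExceeded { provider, .. } => {
                assert_eq!(provider, ZED_CLOUD_PROVIDER_NAME);
            }
            _ => panic!(
                "Expected RateLimitExceeded error for http_429, got: {:?}",
                error
            ),
        }
    }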

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
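
    // A supplementary sketch (not part of the original tests) of `TokenUsage`
    // arithmetic: `Add`/`Sub` operate field-wise, and all four counters feed
    // into `total_tokens`.
    #[test]
    fn test_token_usage_arithmetic() {
        let earlier = TokenUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 3,
        };
        let later = TokenUsage {
            input_tokens: 30,
            output_tokens: 15,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 13,
        };

        assert_eq!(earlier.total_tokens(), 20);

        let sum = earlier + later;
        assert_eq!(sum.input_tokens, 40);
        assert_eq!(sum.output_tokens, 20);
        assert_eq!(sum.cache_creation_input_tokens, 4);
        assert_eq!(sum.cache_read_input_tokens, 16);

        // Subtracting the earlier cumulative snapshot yields the usage in between.
        let delta = later - earlier;
        assert_eq!(delta.total_tokens(), 40);
    }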
}