1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7
8#[cfg(any(test, feature = "test-support"))]
9pub mod fake_provider;
10
11use anthropic::{AnthropicError, parse_prompt_too_long};
12use anyhow::{Result, anyhow};
13use client::Client;
14use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
15use futures::FutureExt;
16use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
17use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
18use http_client::{StatusCode, http};
19use icons::IconName;
20use open_router::OpenRouterError;
21use parking_lot::Mutex;
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize, de::DeserializeOwned};
24use std::ops::{Add, Sub};
25use std::str::FromStr;
26use std::sync::Arc;
27use std::time::Duration;
28use std::{fmt, io};
29use thiserror::Error;
30use util::serde::is_default;
31
32pub use crate::model::*;
33pub use crate::rate_limiter::*;
34pub use crate::registry::*;
35pub use crate::request::*;
36pub use crate::role::*;
37pub use crate::telemetry::*;
38
// Stable identifiers for the built-in providers. The `*_ID` constants are
// machine-readable keys; the `*_NAME` constants are the user-facing display
// strings.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
55
/// Initializes this crate: registers its settings and the listener that
/// refreshes LLM tokens via the given `client`.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

/// Registers only the settings/registry portion of this crate, without the
/// client-backed token listener.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
64
/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    // Maximum number of cache anchors the provider allows per request.
    pub max_cache_anchors: usize,
    // Whether the provider should speculatively write to the cache.
    // NOTE(review): semantics inferred from the name — confirm with providers.
    pub should_speculate: bool,
    // Minimum total token count below which caching is not worthwhile.
    pub min_total_token: u64,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Progress/status report for the underlying completion request.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped producing output, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's visible reasoning, with an optional
    /// provider-issued signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Opaque redacted-reasoning payload, passed through verbatim.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new assistant message began; carries its provider-assigned ID.
    StartMessage {
        message_id: String,
    },
    /// Updated token accounting for the request so far.
    UsageUpdate(TokenUsage),
}
98
/// An error that occurred while producing a completion, classified so that
/// callers can distinguish retryable conditions (rate limits, overload) from
/// permanent ones (malformed request, authentication failure).
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeds the model's context window; `tokens` carries the
    /// offending token count when the provider reported one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is configured for the provider.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// The provider rejected the request due to rate limiting (HTTP 429).
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported overload (HTTP 503, or the non-standard 529).
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported an internal server error (HTTP 500).
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error relayed verbatim from an upstream provider.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Any other HTTP error status without a dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// The request was malformed (HTTP 400).
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// Authentication failed (HTTP 401).
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The credentials lack permission for this request (HTTP 403).
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// The API endpoint was not found (HTTP 404).
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// Reading the response body failed.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// Serializing the outgoing request to JSON failed.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// Building the HTTP request body failed.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed (e.g. a connection error).
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// Deserializing the provider's JSON response failed.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
186
187impl LanguageModelCompletionError {
188 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
189 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
190 let upstream_status = error_json
191 .get("upstream_status")
192 .and_then(|v| v.as_u64())
193 .and_then(|status| u16::try_from(status).ok())
194 .and_then(|status| StatusCode::from_u16(status).ok())?;
195 let inner_message = error_json
196 .get("message")
197 .and_then(|v| v.as_str())
198 .unwrap_or(message)
199 .to_string();
200 Some((upstream_status, inner_message))
201 }
202
203 pub fn from_cloud_failure(
204 upstream_provider: LanguageModelProviderName,
205 code: String,
206 message: String,
207 retry_after: Option<Duration>,
208 ) -> Self {
209 if let Some(tokens) = parse_prompt_too_long(&message) {
210 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
211 // to be reported. This is a temporary workaround to handle this in the case where the
212 // token limit has been exceeded.
213 Self::PromptTooLarge {
214 tokens: Some(tokens),
215 }
216 } else if code == "upstream_http_error" {
217 if let Some((upstream_status, inner_message)) =
218 Self::parse_upstream_error_json(&message)
219 {
220 return Self::from_http_status(
221 upstream_provider,
222 upstream_status,
223 inner_message,
224 retry_after,
225 );
226 }
227 anyhow!("completion request failed, code: {code}, message: {message}").into()
228 } else if let Some(status_code) = code
229 .strip_prefix("upstream_http_")
230 .and_then(|code| StatusCode::from_str(code).ok())
231 {
232 Self::from_http_status(upstream_provider, status_code, message, retry_after)
233 } else if let Some(status_code) = code
234 .strip_prefix("http_")
235 .and_then(|code| StatusCode::from_str(code).ok())
236 {
237 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
238 } else {
239 anyhow!("completion request failed, code: {code}, message: {message}").into()
240 }
241 }
242
243 pub fn from_http_status(
244 provider: LanguageModelProviderName,
245 status_code: StatusCode,
246 message: String,
247 retry_after: Option<Duration>,
248 ) -> Self {
249 match status_code {
250 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
251 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
252 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
253 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
254 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
255 tokens: parse_prompt_too_long(&message),
256 },
257 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
258 provider,
259 retry_after,
260 },
261 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
262 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
263 provider,
264 retry_after,
265 },
266 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
267 provider,
268 retry_after,
269 },
270 _ => Self::HttpResponseError {
271 provider,
272 status_code,
273 message,
274 },
275 }
276 }
277}
278
/// Maps Anthropic transport/protocol errors onto the typed completion error,
/// attributing them to the Anthropic provider.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // API-level errors carry their own code and get their own mapping.
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

/// Maps Anthropic API-level error codes onto the typed completion error.
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            // Unrecognized codes fall back to the untyped catch-all.
            None => Self::Other(error.into()),
        }
    }
}

/// Maps OpenRouter transport/protocol errors onto the typed completion error.
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

/// Maps OpenRouter API-level error codes onto the typed completion error.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // There is no dedicated payment variant; payment problems are
            // surfaced as authentication errors with a prefixed message.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
416
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

/// Why the model stopped generating, mirroring the providers' wire values
/// (serialized in snake_case).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped to call a tool.
    ToolUse,
    /// The model refused to answer.
    Refusal,
}

/// Token counts for a request/response pair. All fields default to zero and
/// are omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Tokens in the prompt.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
446
447impl TokenUsage {
448 pub fn total_tokens(&self) -> u64 {
449 self.input_tokens
450 + self.output_tokens
451 + self.cache_read_input_tokens
452 + self.cache_creation_input_tokens
453 }
454}
455
456impl Add<TokenUsage> for TokenUsage {
457 type Output = Self;
458
459 fn add(self, other: Self) -> Self {
460 Self {
461 input_tokens: self.input_tokens + other.input_tokens,
462 output_tokens: self.output_tokens + other.output_tokens,
463 cache_creation_input_tokens: self.cache_creation_input_tokens
464 + other.cache_creation_input_tokens,
465 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
466 }
467 }
468}
469
470impl Sub<TokenUsage> for TokenUsage {
471 type Output = Self;
472
473 fn sub(self, other: Self) -> Self {
474 Self {
475 input_tokens: self.input_tokens - other.input_tokens,
476 output_tokens: self.output_tokens - other.output_tokens,
477 cache_creation_input_tokens: self.cache_creation_input_tokens
478 - other.cache_creation_input_tokens,
479 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
480 }
481 }
482}
483
/// Identifier for a single tool invocation requested by a model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

/// Allows construction from anything convertible to `Arc<str>`
/// (e.g. `&str`, `String`).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    // The input exactly as streamed from the model, before JSON parsing.
    pub raw_input: String,
    // The parsed form of `raw_input`.
    pub input: serde_json::Value,
    // False while the input is still being streamed incrementally.
    pub is_input_complete: bool,
}

/// A text-only view over a completion stream (see
/// `LanguageModel::stream_completion_text`).
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    /// An empty stream with no message ID and zeroed token usage.
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
527
/// Interface implemented by every language model, regardless of provider.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// ID of the upstream provider actually serving the model. Defaults to
    /// `provider_id`; intended to differ for aggregating providers.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name of the upstream provider actually serving the model.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    /// The configured API key for this model, if any.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode";
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// Schema dialect this model expects for tool input definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Maximum context window size, in tokens.
    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    /// Maximum number of output tokens, when the model declares one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for `request`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Streams only the text of a completion, discarding status, thinking,
    /// tool-use, and stop events. Token usage is published to the returned
    /// stream's `last_token_usage` as `UsageUpdate` events arrive.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Consume the first event eagerly to capture the message ID (or
            // the first text chunk) before building the stream.
            // NOTE(review): a first event of any other kind — including an
            // `Err` or a `UsageUpdate` — is silently dropped here; confirm
            // this is intentional.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend the first text chunk (if any), then keep only text
            // chunks and errors; usage updates are recorded as a side effect.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Prompt-caching configuration, for providers that support it.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    /// Downcast helper for tests; panics unless the model is a fake.
    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
663
/// Extension methods available on `dyn LanguageModel`.
pub trait LanguageModelExt: LanguageModel {
    /// Maximum token count for the given completion mode; `Max` ("burn mode")
    /// falls back to the normal limit when the model declares no
    /// burn-mode-specific limit.
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// A tool that can be exposed to a language model: named, described, and with
/// an input schema derived from the implementing type via `JsonSchema`.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}
680
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
691
/// Interface implemented by each language model provider.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, when one is available.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A cheaper/faster default model, when the provider offers one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently exposes.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration UI.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for the provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// An external agent, identified by display name.
    Other(SharedString),
}

/// Where a provider's terms-of-service notice is being displayed.
// NOTE(review): the comments on the next two variants appear swapped relative
// to the variant names ("EmptyState" = some interactions?) — confirm.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
731
732pub trait LanguageModelProviderState: 'static {
733 type ObservableEntity;
734
735 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
736
737 fn subscribe<T: 'static>(
738 &self,
739 cx: &mut gpui::Context<T>,
740 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
741 ) -> Option<gpui::Subscription> {
742 let entity = self.observable_entity()?;
743 Some(cx.observe(&entity, move |this, _, cx| {
744 callback(this, cx);
745 }))
746 }
747}
748
/// Unique identifier for a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Machine-readable identifier for a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable display name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    /// Const constructor from a static string, usable in `const` items.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
766
767impl LanguageModelProviderName {
768 pub const fn new(id: &'static str) -> Self {
769 Self(SharedString::new_static(id))
770 }
771}
772
// Display the inner string directly, so IDs/names interpolate cleanly into
// user-facing messages (e.g. the `#[error]` strings above).
impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
784
// Owned-string conversions for the newtype wrappers, so dynamic values
// (e.g. parsed from settings or API responses) can be wrapped directly.
impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
820
#[cfg(test)]
mod tests {
    use super::*;

    /// `upstream_http_error` failures carry the real upstream status inside a
    /// JSON payload; verify they are re-classified by that embedded status
    /// (503 → ServerOverloaded, 500 → ApiInternalServerError).
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                // The inner JSON "message" field is surfaced, not the raw payload.
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    /// Codes of the form `upstream_http_<status>` carry the status in the
    /// code itself and should be mapped directly.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    /// Connection-timeout payloads keep their full message text when the JSON
    /// body has no shorter inner message to extract.
    // NOTE(review): the first half of this test duplicates the 503 case in
    // `test_from_cloud_failure_with_upstream_http_error` — consider deduping.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}