mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::fmt::Debug;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
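///
/// A minimal sketch of how a consumer might drain these events; the stream
/// setup is assumed context, so the example is marked `ignore`:
///
/// ```ignore
/// while let Some(event) = stream.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
///         LanguageModelCompletionEvent::UsageUpdate(usage) => eprintln!("{usage:?}"),
///         LanguageModelCompletionEvent::Stop(reason) => eprintln!("stopped: {reason:?}"),
///         _ => {}
///     }
/// }
/// ```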
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
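    /// Extracts the upstream status code and inner message from a Zed Cloud
    /// `upstream_http_error` payload. An illustrative (not verbatim) payload:
    ///
    /// ```json
    /// { "code": "upstream_http_error", "message": "...", "upstream_status": 503 }
    /// ```
    ///
    /// Returns `None` if the message is not JSON or lacks a valid `upstream_status`.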
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

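    /// Converts a failure reported by Zed Cloud into a structured error.
    ///
    /// The `code` can take several shapes, all handled below:
    /// - `upstream_http_error`: the status code is embedded in the JSON `message`.
    /// - `upstream_http_<status>`: the upstream provider's status code is in the code itself.
    /// - `http_<status>`: an error from Zed's own servers rather than the upstream provider.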
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: Currently an Anthropic PAYLOAD_TOO_LARGE response may be reported as
            // INTERNAL_SERVER_ERROR. This is a temporary workaround to handle that case
            // when the token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

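    /// Maps an HTTP status code from a provider's API to a structured error.
    ///
    /// Status 529, which some providers (notably Anthropic) use to signal an
    /// overloaded service, is treated the same as `SERVICE_UNAVAILABLE`.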
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

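/// Token counts reported by a provider for a single request. The cache fields
/// follow Anthropic's prompt-caching terminology; providers without prompt
/// caching can leave them at zero.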
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
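    /// Sums all token categories.
    ///
    /// A small doc-test sketch (marked `ignore` since it assumes this crate is
    /// built under the name `language_model`):
    ///
    /// ```ignore
    /// use language_model::TokenUsage;
    ///
    /// let usage = TokenUsage {
    ///     input_tokens: 10,
    ///     output_tokens: 5,
    ///     cache_creation_input_tokens: 2,
    ///     cache_read_input_tokens: 3,
    /// };
    /// assert_eq!(usage.total_tokens(), 20);
    /// ```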
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

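// Note: subtraction is field-wise and unchecked, so it panics in debug builds
// (and wraps in release builds) if `other` reports more tokens than `self`;
// it is presumably intended for diffing a later snapshot against an earlier one.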
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

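/// A tool call requested by the model. `raw_input` holds the JSON text as
/// streamed by the provider, `input` holds its parsed value, and
/// `is_input_complete` indicates whether the arguments have finished streaming.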
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

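/// A text-only adapter over a completion stream, as produced by
/// [`LanguageModel::stream_completion_text`].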
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage only once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Debug for LanguageModelTextStream {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LanguageModelTextStream")
            .field("message_id", &self.message_id)
            .finish()
    }
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

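/// A single model offered by a [`LanguageModelProvider`], responsible for
/// token counting and streaming completions.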
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

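    /// Adapts [`LanguageModel::stream_completion`] into a text-only stream:
    /// the first event is inspected for a message ID, `Text` events are
    /// forwarded, `UsageUpdate` events are recorded into `last_token_usage`,
    /// and all other events are dropped.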
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
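    /// Returns the context window for the given completion mode, falling back
    /// to the normal limit when burn mode is requested but unsupported.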
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}

impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

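    /// Observes the provider's entity, if any, invoking `callback` whenever it
    /// notifies; returns `None` when the provider has no observable state.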
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}