1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7pub mod tool_schema;
8
9#[cfg(any(test, feature = "test-support"))]
10pub mod fake_provider;
11
12use anthropic::{AnthropicError, parse_prompt_too_long};
13use anyhow::{Result, anyhow};
14use client::Client;
15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
16use futures::FutureExt;
17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
19use http_client::{StatusCode, http};
20use icons::IconName;
21use open_router::OpenRouterError;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
40
// Stable identifiers and user-facing display names for the built-in
// providers. The ID is the machine-readable key (used in settings/registry
// lookups); the name is what is shown in the UI.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
60
/// Initializes this crate: registers settings and wires up the listener that
/// refreshes LLM auth tokens for the given client connection.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

/// Registers this crate's settings (the model registry) with the app.
/// Split out from `init` so tests/settings-only callers can skip the client.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A status update for the in-flight request (e.g. from the cloud broker).
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of assistant response text.
    Text(String),
    /// A chunk of the model's reasoning output.
    Thinking {
        text: String,
        // NOTE(review): presumably a provider-issued integrity signature for
        // the thinking block (Anthropic-style) — confirm with producers.
        signature: Option<String>,
    },
    /// Opaque, provider-encrypted reasoning content.
    RedactedThinking {
        data: String,
    },
    /// A (possibly partial) tool call requested by the model.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        /// The unparsable input exactly as the model produced it.
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Marks the start of a new assistant message.
    StartMessage {
        message_id: String,
    },
    /// Cumulative token usage reported by the provider.
    UsageUpdate(TokenUsage),
}
95
/// Errors that can occur while requesting or streaming a completion.
///
/// Variants are grouped roughly by origin: provider/server-side conditions,
/// HTTP client errors, and local request-construction/transport failures.
/// `retry_after` fields carry the provider's suggested back-off, when given.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the
    /// reported prompt size when it could be parsed from the error message.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error relayed verbatim from an upstream provider via the cloud.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses not mapped to a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
183
impl LanguageModelCompletionError {
    /// Parses the JSON payload of a cloud `upstream_http_error`, extracting
    /// the upstream HTTP status plus an inner message. Returns `None` when
    /// the payload is not JSON or `upstream_status` is missing/invalid; when
    /// only the message is missing, the raw payload is used in its place.
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    /// Converts a failure reported by the Zed cloud broker into a typed error.
    ///
    /// Recognized `code` shapes, checked in order:
    /// - any message matching the Anthropic "prompt too long" pattern;
    /// - `upstream_http_error` with a JSON payload carrying the real status;
    /// - `upstream_http_<status>` (error attributed to `upstream_provider`);
    /// - `http_<status>` (error attributed to Zed's own cloud service).
    /// Anything else falls through to `Other`.
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            // Payload didn't carry a usable upstream status; keep it generic.
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    /// Maps an HTTP status from a provider's API onto a typed error.
    /// Unmatched statuses land in the generic `HttpResponseError` variant.
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            // 529 is a non-standard "overloaded" status used by some
            // providers (notably Anthropic).
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}
275
/// One-to-one mapping from Anthropic transport/API errors onto the shared
/// completion-error type, attributing everything to the Anthropic provider.
impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            // Anthropic always supplies a retry hint for rate limits, but
            // only optionally for overload.
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}
307
/// Maps Anthropic's typed API error codes onto the shared completion-error
/// type; errors with an unrecognized code fall back to `Other`.
impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                // API-level errors carry no retry hint (unlike the HTTP-level
                // variants handled in `From<AnthropicError>`).
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}
347
348impl From<open_ai::RequestError> for LanguageModelCompletionError {
349 fn from(error: open_ai::RequestError) -> Self {
350 match error {
351 open_ai::RequestError::HttpResponseError {
352 provider,
353 status_code,
354 body,
355 headers,
356 } => {
357 let retry_after = headers
358 .get(http::header::RETRY_AFTER)
359 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
360 .map(Duration::from_secs);
361
362 Self::from_http_status(provider.into(), status_code, body, retry_after)
363 }
364 open_ai::RequestError::Other(e) => Self::Other(e),
365 }
366 }
367}
368
/// One-to-one mapping from OpenRouter transport errors onto the shared
/// completion-error type, attributing everything to "OpenRouter".
impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}
392
/// Maps OpenRouter's typed API error codes onto the shared completion-error
/// type. Unlike the Anthropic mapping, the code is not optional here.
impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            // Billing problems surface as authentication errors so the UI
            // directs the user back to provider configuration.
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}
434
/// Why the model stopped generating, serialized in snake_case to match
/// provider wire formats.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
443
/// Token counts for a single request/response exchange. All fields default
/// to zero and are omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens in the prompt (excluding cache reads/writes, counted below).
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
455
456impl TokenUsage {
457 pub fn total_tokens(&self) -> u64 {
458 self.input_tokens
459 + self.output_tokens
460 + self.cache_read_input_tokens
461 + self.cache_creation_input_tokens
462 }
463}
464
465impl Add<TokenUsage> for TokenUsage {
466 type Output = Self;
467
468 fn add(self, other: Self) -> Self {
469 Self {
470 input_tokens: self.input_tokens + other.input_tokens,
471 output_tokens: self.output_tokens + other.output_tokens,
472 cache_creation_input_tokens: self.cache_creation_input_tokens
473 + other.cache_creation_input_tokens,
474 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
475 }
476 }
477}
478
479impl Sub<TokenUsage> for TokenUsage {
480 type Output = Self;
481
482 fn sub(self, other: Self) -> Self {
483 Self {
484 input_tokens: self.input_tokens - other.input_tokens,
485 output_tokens: self.output_tokens - other.output_tokens,
486 cache_creation_input_tokens: self.cache_creation_input_tokens
487 - other.cache_creation_input_tokens,
488 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
489 }
490 }
491}
492
/// Newtype over a shared string identifying a single tool-use request;
/// matches tool results back to the originating call.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Blanket conversion from anything string-like (&str, String, Arc<str>, ...).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
510
/// A tool invocation requested by the model, possibly still streaming in.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the tool the model wants to call.
    pub name: Arc<str>,
    /// The tool input exactly as produced by the model, before JSON parsing.
    pub raw_input: String,
    /// Parsed JSON form of `raw_input`.
    pub input: serde_json::Value,
    /// False while the input is still being streamed incrementally.
    pub is_input_complete: bool,
}
519
/// A text-only view of a completion: just the streamed text chunks, plus the
/// message ID (when the provider sent one) and final token usage.
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

/// An empty, already-finished stream with no message ID and zero usage.
impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}
536
537pub trait LanguageModel: Send + Sync {
538 fn id(&self) -> LanguageModelId;
539 fn name(&self) -> LanguageModelName;
540 fn provider_id(&self) -> LanguageModelProviderId;
541 fn provider_name(&self) -> LanguageModelProviderName;
542 fn upstream_provider_id(&self) -> LanguageModelProviderId {
543 self.provider_id()
544 }
545 fn upstream_provider_name(&self) -> LanguageModelProviderName {
546 self.provider_name()
547 }
548
549 fn telemetry_id(&self) -> String;
550
551 fn api_key(&self, _cx: &App) -> Option<String> {
552 None
553 }
554
555 /// Whether this model supports images
556 fn supports_images(&self) -> bool;
557
558 /// Whether this model supports tools.
559 fn supports_tools(&self) -> bool;
560
561 /// Whether this model supports choosing which tool to use.
562 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
563
564 /// Returns whether this model supports "burn mode";
565 fn supports_burn_mode(&self) -> bool {
566 false
567 }
568
569 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
570 LanguageModelToolSchemaFormat::JsonSchema
571 }
572
573 fn max_token_count(&self) -> u64;
574 /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
575 fn max_token_count_in_burn_mode(&self) -> Option<u64> {
576 None
577 }
578 fn max_output_tokens(&self) -> Option<u64> {
579 None
580 }
581
582 fn count_tokens(
583 &self,
584 request: LanguageModelRequest,
585 cx: &App,
586 ) -> BoxFuture<'static, Result<u64>>;
587
588 fn stream_completion(
589 &self,
590 request: LanguageModelRequest,
591 cx: &AsyncApp,
592 ) -> BoxFuture<
593 'static,
594 Result<
595 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
596 LanguageModelCompletionError,
597 >,
598 >;
599
600 fn stream_completion_text(
601 &self,
602 request: LanguageModelRequest,
603 cx: &AsyncApp,
604 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
605 let future = self.stream_completion(request, cx);
606
607 async move {
608 let events = future.await?;
609 let mut events = events.fuse();
610 let mut message_id = None;
611 let mut first_item_text = None;
612 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
613
614 if let Some(first_event) = events.next().await {
615 match first_event {
616 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
617 message_id = Some(id);
618 }
619 Ok(LanguageModelCompletionEvent::Text(text)) => {
620 first_item_text = Some(text);
621 }
622 _ => (),
623 }
624 }
625
626 let stream = futures::stream::iter(first_item_text.map(Ok))
627 .chain(events.filter_map({
628 let last_token_usage = last_token_usage.clone();
629 move |result| {
630 let last_token_usage = last_token_usage.clone();
631 async move {
632 match result {
633 Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
634 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
635 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
636 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
637 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
638 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
639 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
640 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
641 ..
642 }) => None,
643 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
644 *last_token_usage.lock() = token_usage;
645 None
646 }
647 Err(err) => Some(Err(err)),
648 }
649 }
650 }
651 }))
652 .boxed();
653
654 Ok(LanguageModelTextStream {
655 message_id,
656 stream,
657 last_token_usage,
658 })
659 }
660 .boxed()
661 }
662
663 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
664 None
665 }
666
667 #[cfg(any(test, feature = "test-support"))]
668 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
669 unimplemented!()
670 }
671}
672
673pub trait LanguageModelExt: LanguageModel {
674 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
675 match mode {
676 CompletionMode::Normal => self.max_token_count(),
677 CompletionMode::Max => self
678 .max_token_count_in_burn_mode()
679 .unwrap_or_else(|| self.max_token_count()),
680 }
681 }
682}
683impl LanguageModelExt for dyn LanguageModel {}
684
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials (API key, token) were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
695
/// A source of language models (Anthropic, OpenAI, Zed cloud, ...): exposes
/// its models, authentication state, and a configuration UI.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown next to the provider in pickers; defaults to Zed's own.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if it can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A cheaper/faster default, for lightweight tasks.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models to highlight in the UI; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's settings/credentials UI.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
718
/// Which agent a provider configuration view is being shown for, so the view
/// can tailor its copy.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Some other agent, identified by its display name.
    Other(SharedString),
}
725
/// Where a provider's terms-of-service prompt is rendered.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Shown inside the text-thread popup.
    TextThreadPopup,
    /// Shown in the provider configuration screen.
    Configuration,
}
735
/// Observable state for a provider, letting UI entities re-render when the
/// provider's backing entity changes.
pub trait LanguageModelProviderState: 'static {
    /// The gpui entity whose changes callers may observe.
    type ObservableEntity;

    /// The entity to observe, or `None` if this provider has no observable state.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Invokes `callback` whenever the observable entity notifies; returns
    /// `None` (no subscription) when there is nothing to observe.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
752
/// Machine-readable identifier of a model (e.g. "claude-sonnet-4").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable display name of a model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Machine-readable identifier of a provider (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable display name of a provider (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
764
765impl LanguageModelProviderId {
766 pub const fn new(id: &'static str) -> Self {
767 Self(SharedString::new_static(id))
768 }
769}
770
771impl LanguageModelProviderName {
772 pub const fn new(id: &'static str) -> Self {
773 Self(SharedString::new_static(id))
774 }
775}
776
impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegates to the inner SharedString's Display.
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
788
// Owned-string conversions into the newtype wrappers. `From<String>` covers
// values parsed from settings/serialized data; `From<Arc<str>>` covers
// shared strings coming from other in-process components.
impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}
824
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies that an "upstream_http_error" code with a JSON payload is
    // unwrapped to the *upstream* status (503 -> ServerOverloaded,
    // 500 -> ApiInternalServerError) and attributed to the upstream provider.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Verifies the "upstream_http_<status>" code shape (status embedded in
    // the code rather than in a JSON payload).
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // Connection-timeout payloads: the inner message must survive intact and
    // the mapping must still follow the upstream status.
    // NOTE(review): the 503 half duplicates the first case of
    // test_from_cloud_failure_with_upstream_http_error.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}