mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub const OLLAMA_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
pub const OLLAMA_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Ollama");

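/// Initializes the language model crate: registers its settings and the
/// LLM token refresh listener for the given client.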
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
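    /// Attempts to extract the upstream HTTP status and inner message from a
    /// cloud `upstream_http_error` payload. The expected shape, as exercised
    /// by the tests at the bottom of this module:
    ///
    /// ```json
    /// {"code":"upstream_http_error","message":"...","upstream_status":503}
    /// ```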
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

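    /// Maps a failure reported by Zed's cloud proxy onto a typed completion
    /// error, recognizing the `upstream_http_error` JSON format as well as
    /// `upstream_http_<status>` and `http_<status>` code strings.
    ///
    /// A sketch mirroring the tests at the bottom of this module (`ignore`d
    /// because it references crate-local items):
    ///
    /// ```ignore
    /// let err = LanguageModelCompletionError::from_cloud_failure(
    ///     LanguageModelProviderName::new("anthropic"),
    ///     "upstream_http_503".to_string(),
    ///     "Service unavailable".to_string(),
    ///     None,
    /// );
    /// assert!(matches!(err, LanguageModelCompletionError::ServerOverloaded { .. }));
    /// ```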
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: Currently, an Anthropic PAYLOAD_TOO_LARGE response may be reported as an
            // INTERNAL_SERVER_ERROR. This is a temporary workaround that handles the case where
            // the token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

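    /// Maps an HTTP status from a provider's API onto the matching error
    /// variant: 400/401/403/404 become the corresponding client errors,
    /// 413 becomes `PromptTooLarge`, 429 becomes `RateLimitExceeded`, and
    /// 500, 503, or the non-standard 529 become server-side errors.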
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
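    /// Returns the total of all four token counters.
    ///
    /// A minimal sketch (`ignore`d because the crate path is assumed):
    ///
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 25,
    ///     cache_creation_input_tokens: 10,
    ///     cache_read_input_tokens: 5,
    /// };
    /// assert_eq!(usage.total_tokens(), 140);
    /// ```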
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

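// Note: `sub` uses plain `-` on `u64` fields, so it panics in debug builds
// (and wraps in release builds) if `other` is not component-wise smaller or
// equal. It is only well-defined when subtracting an earlier cumulative
// snapshot from a later one.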
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

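/// A text-only view of a completion: the optional provider-assigned message
/// ID, a stream of text chunks, and a shared cell that holds the final token
/// usage once the stream has been consumed.
///
/// A usage sketch (`ignore`d since it assumes a `model`, `request`, and `cx`
/// in scope):
///
/// ```ignore
/// let LanguageModelTextStream { mut stream, last_token_usage, .. } =
///     model.stream_completion_text(request, &cx).await?;
/// while let Some(chunk) = stream.next().await {
///     print!("{}", chunk?);
/// }
/// let final_usage = *last_token_usage.lock();
/// ```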
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Contains the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

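/// Implemented by each concrete language model. Covers identity metadata,
/// capability queries (images, tools, burn mode), token accounting, and
/// streaming completions.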
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

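            // Peek at the first event so the message ID (and any leading text
            // chunk) can be captured before handing the rest of the stream
            // back to the caller.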
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

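            // Re-attach the peeked text chunk (if any), then keep only text
            // events, recording usage updates into the shared counter as a
            // side effect and dropping everything else.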
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

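/// Convenience methods layered on top of [`LanguageModel`], implemented for
/// `dyn LanguageModel` trait objects.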
pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}

impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

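// `new` is `const` so provider IDs and names can be declared as module-level
// constants, as with the `*_PROVIDER_ID` / `*_PROVIDER_NAME` definitions at
// the top of this module.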
impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
}