1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7pub mod tool_schema;
8
9#[cfg(any(test, feature = "test-support"))]
10pub mod fake_provider;
11
12use anthropic::{AnthropicError, parse_prompt_too_long};
13use anyhow::{Result, anyhow};
14use client::Client;
15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
16use futures::FutureExt;
17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
19use http_client::{StatusCode, http};
20use icons::IconName;
21use open_router::OpenRouterError;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
40
// Well-known built-in providers: a stable machine-readable id (used in settings
// and telemetry) paired with a human-readable display name.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

// Completions proxied through Zed's own cloud service.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
60
/// Initializes the language-model subsystem: registers settings and starts the
/// listener that refreshes LLM tokens via the given cloud `client`.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
65
/// Registers only the language-model registry/settings, without any client
/// listeners.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A status report about the request itself (e.g. queueing) rather than model output.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of assistant text.
    Text(String),
    /// A chunk of the model's visible reasoning, with an optional signature
    /// (some models require the signature to be echoed back in later requests —
    /// see `LanguageModelToolUse::thought_signature`).
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// An opaque, redacted reasoning payload.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model emitted tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new assistant message started; carries its message id.
    StartMessage {
        message_id: String,
    },
    /// Updated token-usage counters for the request.
    UsageUpdate(TokenUsage),
}
95
/// Errors produced while requesting or streaming a completion from a provider.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded from an upstream provider, displayed verbatim.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Any non-success HTTP response not covered by a more specific variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    // Local failures while building, sending, or decoding the request/response.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
183
184impl LanguageModelCompletionError {
185 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
186 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
187 let upstream_status = error_json
188 .get("upstream_status")
189 .and_then(|v| v.as_u64())
190 .and_then(|status| u16::try_from(status).ok())
191 .and_then(|status| StatusCode::from_u16(status).ok())?;
192 let inner_message = error_json
193 .get("message")
194 .and_then(|v| v.as_str())
195 .unwrap_or(message)
196 .to_string();
197 Some((upstream_status, inner_message))
198 }
199
200 pub fn from_cloud_failure(
201 upstream_provider: LanguageModelProviderName,
202 code: String,
203 message: String,
204 retry_after: Option<Duration>,
205 ) -> Self {
206 if let Some(tokens) = parse_prompt_too_long(&message) {
207 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
208 // to be reported. This is a temporary workaround to handle this in the case where the
209 // token limit has been exceeded.
210 Self::PromptTooLarge {
211 tokens: Some(tokens),
212 }
213 } else if code == "upstream_http_error" {
214 if let Some((upstream_status, inner_message)) =
215 Self::parse_upstream_error_json(&message)
216 {
217 return Self::from_http_status(
218 upstream_provider,
219 upstream_status,
220 inner_message,
221 retry_after,
222 );
223 }
224 anyhow!("completion request failed, code: {code}, message: {message}").into()
225 } else if let Some(status_code) = code
226 .strip_prefix("upstream_http_")
227 .and_then(|code| StatusCode::from_str(code).ok())
228 {
229 Self::from_http_status(upstream_provider, status_code, message, retry_after)
230 } else if let Some(status_code) = code
231 .strip_prefix("http_")
232 .and_then(|code| StatusCode::from_str(code).ok())
233 {
234 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
235 } else {
236 anyhow!("completion request failed, code: {code}, message: {message}").into()
237 }
238 }
239
240 pub fn from_http_status(
241 provider: LanguageModelProviderName,
242 status_code: StatusCode,
243 message: String,
244 retry_after: Option<Duration>,
245 ) -> Self {
246 match status_code {
247 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
248 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
249 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
250 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
251 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
252 tokens: parse_prompt_too_long(&message),
253 },
254 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
255 provider,
256 retry_after,
257 },
258 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
259 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
260 provider,
261 retry_after,
262 },
263 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
264 provider,
265 retry_after,
266 },
267 _ => Self::HttpResponseError {
268 provider,
269 status_code,
270 message,
271 },
272 }
273 }
274}
275
276impl From<AnthropicError> for LanguageModelCompletionError {
277 fn from(error: AnthropicError) -> Self {
278 let provider = ANTHROPIC_PROVIDER_NAME;
279 match error {
280 AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
281 AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
282 AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
283 AnthropicError::DeserializeResponse(error) => {
284 Self::DeserializeResponse { provider, error }
285 }
286 AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
287 AnthropicError::HttpResponseError {
288 status_code,
289 message,
290 } => Self::HttpResponseError {
291 provider,
292 status_code,
293 message,
294 },
295 AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
296 provider,
297 retry_after: Some(retry_after),
298 },
299 AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
300 provider,
301 retry_after,
302 },
303 AnthropicError::ApiError(api_error) => api_error.into(),
304 }
305 }
306}
307
308impl From<anthropic::ApiError> for LanguageModelCompletionError {
309 fn from(error: anthropic::ApiError) -> Self {
310 use anthropic::ApiErrorCode::*;
311 let provider = ANTHROPIC_PROVIDER_NAME;
312 match error.code() {
313 Some(code) => match code {
314 InvalidRequestError => Self::BadRequestFormat {
315 provider,
316 message: error.message,
317 },
318 AuthenticationError => Self::AuthenticationError {
319 provider,
320 message: error.message,
321 },
322 PermissionError => Self::PermissionError {
323 provider,
324 message: error.message,
325 },
326 NotFoundError => Self::ApiEndpointNotFound { provider },
327 RequestTooLarge => Self::PromptTooLarge {
328 tokens: parse_prompt_too_long(&error.message),
329 },
330 RateLimitError => Self::RateLimitExceeded {
331 provider,
332 retry_after: None,
333 },
334 ApiError => Self::ApiInternalServerError {
335 provider,
336 message: error.message,
337 },
338 OverloadedError => Self::ServerOverloaded {
339 provider,
340 retry_after: None,
341 },
342 },
343 None => Self::Other(error.into()),
344 }
345 }
346}
347
348impl From<open_ai::RequestError> for LanguageModelCompletionError {
349 fn from(error: open_ai::RequestError) -> Self {
350 match error {
351 open_ai::RequestError::HttpResponseError {
352 provider,
353 status_code,
354 body,
355 headers,
356 } => {
357 let retry_after = headers
358 .get(http::header::RETRY_AFTER)
359 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
360 .map(Duration::from_secs);
361
362 Self::from_http_status(provider.into(), status_code, body, retry_after)
363 }
364 open_ai::RequestError::Other(e) => Self::Other(e),
365 }
366 }
367}
368
369impl From<OpenRouterError> for LanguageModelCompletionError {
370 fn from(error: OpenRouterError) -> Self {
371 let provider = LanguageModelProviderName::new("OpenRouter");
372 match error {
373 OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
374 OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
375 OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
376 OpenRouterError::DeserializeResponse(error) => {
377 Self::DeserializeResponse { provider, error }
378 }
379 OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
380 OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
381 provider,
382 retry_after: Some(retry_after),
383 },
384 OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
385 provider,
386 retry_after,
387 },
388 OpenRouterError::ApiError(api_error) => api_error.into(),
389 }
390 }
391}
392
393impl From<open_router::ApiError> for LanguageModelCompletionError {
394 fn from(error: open_router::ApiError) -> Self {
395 use open_router::ApiErrorCode::*;
396 let provider = LanguageModelProviderName::new("OpenRouter");
397 match error.code {
398 InvalidRequestError => Self::BadRequestFormat {
399 provider,
400 message: error.message,
401 },
402 AuthenticationError => Self::AuthenticationError {
403 provider,
404 message: error.message,
405 },
406 PaymentRequiredError => Self::AuthenticationError {
407 provider,
408 message: format!("Payment required: {}", error.message),
409 },
410 PermissionError => Self::PermissionError {
411 provider,
412 message: error.message,
413 },
414 RequestTimedOut => Self::HttpResponseError {
415 provider,
416 status_code: StatusCode::REQUEST_TIMEOUT,
417 message: error.message,
418 },
419 RateLimitError => Self::RateLimitExceeded {
420 provider,
421 retry_after: None,
422 },
423 ApiError => Self::ApiInternalServerError {
424 provider,
425 message: error.message,
426 },
427 OverloadedError => Self::ServerOverloaded {
428 provider,
429 retry_after: None,
430 },
431 }
432 }
433}
434
/// Why the model stopped emitting tokens.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Natural end of the assistant turn.
    EndTurn,
    /// Output was cut off by the token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model refused to produce a response.
    Refusal,
}
443
/// Token counts reported by a provider for a single request.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input tokens processed fresh (not served from a cache).
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
455
456impl TokenUsage {
457 pub fn total_tokens(&self) -> u64 {
458 self.input_tokens
459 + self.output_tokens
460 + self.cache_read_input_tokens
461 + self.cache_creation_input_tokens
462 }
463}
464
465impl Add<TokenUsage> for TokenUsage {
466 type Output = Self;
467
468 fn add(self, other: Self) -> Self {
469 Self {
470 input_tokens: self.input_tokens + other.input_tokens,
471 output_tokens: self.output_tokens + other.output_tokens,
472 cache_creation_input_tokens: self.cache_creation_input_tokens
473 + other.cache_creation_input_tokens,
474 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
475 }
476 }
477}
478
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    // Field-wise difference of two usage records.
    // NOTE(review): plain `u64` subtraction — if any field of `other` exceeds
    // the corresponding field of `self`, this panics in debug builds (and wraps
    // in release builds unless overflow checks are enabled). Presumably callers
    // always subtract an earlier snapshot from a later one — confirm.
    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}
492
/// Identifier for a tool use within a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

// Any string-like value can serve as a tool-use id.
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
510
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The tool input as received from the model, before JSON parsing.
    pub raw_input: String,
    /// Parsed JSON form of the tool input.
    pub input: serde_json::Value,
    /// Whether `input` reflects the complete tool input.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
522
/// A text-only view of a completion stream, as produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    /// Message id from the stream's `StartMessage` event, when one was seen.
    pub message_id: Option<String>,
    /// Stream of text chunks; non-text events are filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
529
530impl Default for LanguageModelTextStream {
531 fn default() -> Self {
532 Self {
533 message_id: None,
534 stream: Box::pin(futures::stream::empty()),
535 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
536 }
537 }
538}
539
/// Interface implemented by every language model, regardless of provider.
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    /// Id of the provider that actually serves requests for this model;
    /// defaults to `provider_id`.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    /// Name counterpart of `upstream_provider_id`.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    /// Identifier used for this model in telemetry events.
    fn telemetry_id(&self) -> String;

    /// API key for this model, if one is configured; `None` by default.
    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// Schema format this model expects for tool definitions.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    /// Size of the model's context window, in tokens.
    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    /// Maximum number of output tokens, when the model imposes one.
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    /// Counts the tokens `request` would consume for this model.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    /// Streams completion events for `request`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    /// Streams only the text of a completion, dropping tool-use, thinking, and
    /// other non-text events. Usage updates are folded into the returned
    /// stream's `last_token_usage`.
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: it may carry the message id, or it may
            // already be the first text chunk, which must not be dropped.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Re-prepend any text consumed above, then keep only text chunks
            // and errors; usage updates are recorded as a side effect.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    /// Prompt-caching configuration, when the model supports it.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}
675
676pub trait LanguageModelExt: LanguageModel {
677 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
678 match mode {
679 CompletionMode::Normal => self.max_token_count(),
680 CompletionMode::Max => self
681 .max_token_count_in_burn_mode()
682 .unwrap_or_else(|| self.max_token_count()),
683 }
684 }
685}
686impl LanguageModelExt for dyn LanguageModel {}
687
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider's endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
698
/// A source of language models (e.g. Anthropic, OpenAI, Zed's cloud service).
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A faster/cheaper default model, if the provider distinguishes one.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models to highlight; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's configuration/settings view.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
721
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent.
    #[default]
    ZedAgent,
    /// An external agent, identified by name.
    Other(SharedString),
}

/// Where a provider's terms-of-service notice is being displayed.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// In the text thread popup.
    TextThreadPopup,
    /// In the provider configuration view.
    Configuration,
}
738
/// Implemented by providers that expose an observable entity, so UI can react
/// when provider state changes.
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    /// The entity to observe, if the provider has one.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    /// Invokes `callback` whenever the observable entity notifies; returns
    /// `None` when the provider exposes no observable entity.
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
755
/// Unique identifier for a model within its provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable, machine-readable provider identifier (e.g. "anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name (e.g. "Anthropic").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    /// Const constructor from a static string, used for the provider constants
    /// defined in this file.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Const constructor from a static string, used for the provider constants
    /// defined in this file.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
779
780impl fmt::Display for LanguageModelProviderId {
781 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
782 write!(f, "{}", self.0)
783 }
784}
785
786impl fmt::Display for LanguageModelProviderName {
787 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
788 write!(f, "{}", self.0)
789 }
790}
791
792impl From<String> for LanguageModelId {
793 fn from(value: String) -> Self {
794 Self(SharedString::from(value))
795 }
796}
797
798impl From<String> for LanguageModelName {
799 fn from(value: String) -> Self {
800 Self(SharedString::from(value))
801 }
802}
803
804impl From<String> for LanguageModelProviderId {
805 fn from(value: String) -> Self {
806 Self(SharedString::from(value))
807 }
808}
809
810impl From<String> for LanguageModelProviderName {
811 fn from(value: String) -> Self {
812 Self(SharedString::from(value))
813 }
814}
815
816impl From<Arc<str>> for LanguageModelProviderId {
817 fn from(value: Arc<str>) -> Self {
818 Self(SharedString::from(value))
819 }
820}
821
822impl From<Arc<str>> for LanguageModelProviderName {
823 fn from(value: Arc<str>) -> Self {
824 Self(SharedString::from(value))
825 }
826}
827
#[cfg(test)]
mod tests {
    use super::*;

    // An `upstream_http_error` code with a JSON body should be mapped using the
    // embedded `upstream_status` (503 -> ServerOverloaded, 500 -> internal error).
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // The `upstream_http_<status>` code format should map directly onto the
    // corresponding HTTP-status error variant.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // Connection-timeout messages embedded in the JSON payload keep their inner
    // message and are classified solely by the embedded upstream status.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // `thought_signature` should appear in the serialized JSON when present.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Payloads without `thought_signature` (e.g. from older versions) must
    // still deserialize, defaulting the field to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Serialize -> deserialize round trip preserves a present signature.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Serialize -> deserialize round trip preserves an absent signature.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}