1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7pub mod tool_schema;
8
9#[cfg(any(test, feature = "test-support"))]
10pub mod fake_provider;
11
12use anthropic::{AnthropicError, parse_prompt_too_long};
13use anyhow::{Result, anyhow};
14use client::Client;
15use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
16use futures::FutureExt;
17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
19use http_client::{StatusCode, http};
20use icons::IconName;
21use open_router::OpenRouterError;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
40
/// Provider id for Anthropic.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
/// Display name for Anthropic. Used as the `provider` of errors converted from
/// [`AnthropicError`] and [`anthropic::ApiError`].
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Provider id for Google AI.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// Display name for Google AI.
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Provider id for OpenAI.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Display name for OpenAI.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Provider id for xAI.
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// Display name for xAI.
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

/// Provider id for Zed's hosted service (`zed.dev`).
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// Display name for Zed's hosted service. Used as the `provider` of errors that
/// [`LanguageModelCompletionError::from_cloud_failure`] attributes to Zed itself
/// (`http_<status>` failure codes).
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
60
61pub fn init(client: Arc<Client>, cx: &mut App) {
62 init_settings(cx);
63 RefreshLlmTokenListener::register(client, cx);
64}
65
66pub fn init_settings(cx: &mut App) {
67 registry::init(cx);
68}
69
/// A single event in the stream returned by [`LanguageModel::stream_completion`].
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A status update about the completion request.
    StatusUpdate(CompletionRequestStatus),
    /// Generation stopped, for the given [`StopReason`].
    Stop(StopReason),
    /// A chunk of generated text. These are the only events that
    /// [`LanguageModel::stream_completion_text`] forwards as text.
    Text(String),
    /// Reasoning text emitted by the model, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning delivered as opaque `data` instead of text (presumably redacted by the
    /// provider).
    RedactedThinking {
        data: String,
    },
    /// A tool call. It may be partial; see [`LanguageModelToolUse::is_input_complete`].
    ToolUse(LanguageModelToolUse),
    /// A tool call whose `raw_input` could not be parsed as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Start of a message. When this is the first event, `stream_completion_text` reports
    /// `message_id` as the stream's message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning details, kept as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// A token usage report. `stream_completion_text` keeps only the most recent one.
    UsageUpdate(TokenUsage),
}
96
/// Errors produced while requesting or streaming a completion.
///
/// Provider client errors are converted into this type by the `From` impls in this file.
/// Raw HTTP statuses are converted by [`LanguageModelCompletionError::from_http_status`].
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt does not fit the model's context window. `tokens` is the prompt size, when
    /// it could be parsed from the provider's message (HTTP 413 or a "prompt too long" message).
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key is available for `provider`. Not constructed in this file.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// Rate limited (HTTP 429 or a rate-limit API error). `retry_after` is the suggested
    /// delay, when one was provided.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider is overloaded (HTTP 503, HTTP 529, or an overloaded API error).
    /// `retry_after` is the suggested delay, when one was provided.
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// The provider reported an internal error (HTTP 500 or a generic API error).
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error relayed from an upstream provider. Displayed as just `message`.
    /// Not constructed in this file.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// An HTTP error status with no more specific variant. The message is shown with `{:?}`
    /// formatting.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// The request was rejected as malformed (HTTP 400 / `InvalidRequestError`).
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// Authentication failed (HTTP 401 / `AuthenticationError`). Also used for OpenRouter's
    /// `PaymentRequiredError`.
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// Access was denied (HTTP 403 / `PermissionError`).
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 404 / `NotFoundError`. The response message is not kept.
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// Reading the response failed with an I/O error.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// Serializing the request to JSON failed.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// Building the HTTP request failed.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// Deserializing the response JSON failed.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    /// Any other failure; displayed exactly as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
184
185impl LanguageModelCompletionError {
186 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
187 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
188 let upstream_status = error_json
189 .get("upstream_status")
190 .and_then(|v| v.as_u64())
191 .and_then(|status| u16::try_from(status).ok())
192 .and_then(|status| StatusCode::from_u16(status).ok())?;
193 let inner_message = error_json
194 .get("message")
195 .and_then(|v| v.as_str())
196 .unwrap_or(message)
197 .to_string();
198 Some((upstream_status, inner_message))
199 }
200
201 pub fn from_cloud_failure(
202 upstream_provider: LanguageModelProviderName,
203 code: String,
204 message: String,
205 retry_after: Option<Duration>,
206 ) -> Self {
207 if let Some(tokens) = parse_prompt_too_long(&message) {
208 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
209 // to be reported. This is a temporary workaround to handle this in the case where the
210 // token limit has been exceeded.
211 Self::PromptTooLarge {
212 tokens: Some(tokens),
213 }
214 } else if code == "upstream_http_error" {
215 if let Some((upstream_status, inner_message)) =
216 Self::parse_upstream_error_json(&message)
217 {
218 return Self::from_http_status(
219 upstream_provider,
220 upstream_status,
221 inner_message,
222 retry_after,
223 );
224 }
225 anyhow!("completion request failed, code: {code}, message: {message}").into()
226 } else if let Some(status_code) = code
227 .strip_prefix("upstream_http_")
228 .and_then(|code| StatusCode::from_str(code).ok())
229 {
230 Self::from_http_status(upstream_provider, status_code, message, retry_after)
231 } else if let Some(status_code) = code
232 .strip_prefix("http_")
233 .and_then(|code| StatusCode::from_str(code).ok())
234 {
235 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
236 } else {
237 anyhow!("completion request failed, code: {code}, message: {message}").into()
238 }
239 }
240
241 pub fn from_http_status(
242 provider: LanguageModelProviderName,
243 status_code: StatusCode,
244 message: String,
245 retry_after: Option<Duration>,
246 ) -> Self {
247 match status_code {
248 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
249 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
250 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
251 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
252 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
253 tokens: parse_prompt_too_long(&message),
254 },
255 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
256 provider,
257 retry_after,
258 },
259 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
260 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
261 provider,
262 retry_after,
263 },
264 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
265 provider,
266 retry_after,
267 },
268 _ => Self::HttpResponseError {
269 provider,
270 status_code,
271 message,
272 },
273 }
274 }
275}
276
277impl From<AnthropicError> for LanguageModelCompletionError {
278 fn from(error: AnthropicError) -> Self {
279 let provider = ANTHROPIC_PROVIDER_NAME;
280 match error {
281 AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
282 AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
283 AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
284 AnthropicError::DeserializeResponse(error) => {
285 Self::DeserializeResponse { provider, error }
286 }
287 AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
288 AnthropicError::HttpResponseError {
289 status_code,
290 message,
291 } => Self::HttpResponseError {
292 provider,
293 status_code,
294 message,
295 },
296 AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
297 provider,
298 retry_after: Some(retry_after),
299 },
300 AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
301 provider,
302 retry_after,
303 },
304 AnthropicError::ApiError(api_error) => api_error.into(),
305 }
306 }
307}
308
309impl From<anthropic::ApiError> for LanguageModelCompletionError {
310 fn from(error: anthropic::ApiError) -> Self {
311 use anthropic::ApiErrorCode::*;
312 let provider = ANTHROPIC_PROVIDER_NAME;
313 match error.code() {
314 Some(code) => match code {
315 InvalidRequestError => Self::BadRequestFormat {
316 provider,
317 message: error.message,
318 },
319 AuthenticationError => Self::AuthenticationError {
320 provider,
321 message: error.message,
322 },
323 PermissionError => Self::PermissionError {
324 provider,
325 message: error.message,
326 },
327 NotFoundError => Self::ApiEndpointNotFound { provider },
328 RequestTooLarge => Self::PromptTooLarge {
329 tokens: parse_prompt_too_long(&error.message),
330 },
331 RateLimitError => Self::RateLimitExceeded {
332 provider,
333 retry_after: None,
334 },
335 ApiError => Self::ApiInternalServerError {
336 provider,
337 message: error.message,
338 },
339 OverloadedError => Self::ServerOverloaded {
340 provider,
341 retry_after: None,
342 },
343 },
344 None => Self::Other(error.into()),
345 }
346 }
347}
348
349impl From<open_ai::RequestError> for LanguageModelCompletionError {
350 fn from(error: open_ai::RequestError) -> Self {
351 match error {
352 open_ai::RequestError::HttpResponseError {
353 provider,
354 status_code,
355 body,
356 headers,
357 } => {
358 let retry_after = headers
359 .get(http::header::RETRY_AFTER)
360 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
361 .map(Duration::from_secs);
362
363 Self::from_http_status(provider.into(), status_code, body, retry_after)
364 }
365 open_ai::RequestError::Other(e) => Self::Other(e),
366 }
367 }
368}
369
370impl From<OpenRouterError> for LanguageModelCompletionError {
371 fn from(error: OpenRouterError) -> Self {
372 let provider = LanguageModelProviderName::new("OpenRouter");
373 match error {
374 OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
375 OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
376 OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
377 OpenRouterError::DeserializeResponse(error) => {
378 Self::DeserializeResponse { provider, error }
379 }
380 OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
381 OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
382 provider,
383 retry_after: Some(retry_after),
384 },
385 OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
386 provider,
387 retry_after,
388 },
389 OpenRouterError::ApiError(api_error) => api_error.into(),
390 }
391 }
392}
393
394impl From<open_router::ApiError> for LanguageModelCompletionError {
395 fn from(error: open_router::ApiError) -> Self {
396 use open_router::ApiErrorCode::*;
397 let provider = LanguageModelProviderName::new("OpenRouter");
398 match error.code {
399 InvalidRequestError => Self::BadRequestFormat {
400 provider,
401 message: error.message,
402 },
403 AuthenticationError => Self::AuthenticationError {
404 provider,
405 message: error.message,
406 },
407 PaymentRequiredError => Self::AuthenticationError {
408 provider,
409 message: format!("Payment required: {}", error.message),
410 },
411 PermissionError => Self::PermissionError {
412 provider,
413 message: error.message,
414 },
415 RequestTimedOut => Self::HttpResponseError {
416 provider,
417 status_code: StatusCode::REQUEST_TIMEOUT,
418 message: error.message,
419 },
420 RateLimitError => Self::RateLimitExceeded {
421 provider,
422 retry_after: None,
423 },
424 ApiError => Self::ApiInternalServerError {
425 provider,
426 message: error.message,
427 },
428 OverloadedError => Self::ServerOverloaded {
429 provider,
430 retry_after: None,
431 },
432 }
433 }
434}
435
/// The reason carried by [`LanguageModelCompletionEvent::Stop`].
///
/// Serialized in `snake_case`, e.g. `MaxTokens` becomes `"max_tokens"`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}
444
/// Token counts for a completion, split by category.
///
/// Each field is skipped when serializing if it equals its default (`is_default`), and falls
/// back to its default when missing during deserialization.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Output (generated) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens read from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
456
457impl TokenUsage {
458 pub fn total_tokens(&self) -> u64 {
459 self.input_tokens
460 + self.output_tokens
461 + self.cache_read_input_tokens
462 + self.cache_creation_input_tokens
463 }
464}
465
466impl Add<TokenUsage> for TokenUsage {
467 type Output = Self;
468
469 fn add(self, other: Self) -> Self {
470 Self {
471 input_tokens: self.input_tokens + other.input_tokens,
472 output_tokens: self.output_tokens + other.output_tokens,
473 cache_creation_input_tokens: self.cache_creation_input_tokens
474 + other.cache_creation_input_tokens,
475 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
476 }
477 }
478}
479
480impl Sub<TokenUsage> for TokenUsage {
481 type Output = Self;
482
483 fn sub(self, other: Self) -> Self {
484 Self {
485 input_tokens: self.input_tokens - other.input_tokens,
486 output_tokens: self.output_tokens - other.output_tokens,
487 cache_creation_input_tokens: self.cache_creation_input_tokens
488 - other.cache_creation_input_tokens,
489 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
490 }
491 }
492}
493
/// Opaque identifier of a tool use. Backed by `Arc<str>`, so clones are cheap.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
496
497impl fmt::Display for LanguageModelToolUseId {
498 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
499 write!(f, "{}", self.0)
500 }
501}
502
503impl<T> From<T> for LanguageModelToolUseId
504where
505 T: Into<Arc<str>>,
506{
507 fn from(value: T) -> Self {
508 Self(value.into())
509 }
510}
511
/// A tool call requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool to invoke.
    pub name: Arc<str>,
    /// The tool input as raw text (presumably the unparsed JSON the model produced).
    pub raw_input: String,
    /// The tool input as a JSON value (presumably parsed from `raw_input`).
    pub input: serde_json::Value,
    /// Whether the input is complete; `false` indicates a partial tool call.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
523
/// A text-only view of a completion, as returned by
/// [`LanguageModel::stream_completion_text`].
pub struct LanguageModelTextStream {
    /// Id from a leading `StartMessage` event, if the completion began with one.
    pub message_id: Option<String>,
    /// Text chunks and errors from the completion; every other event is filtered out.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished. It is overwritten by each
    // `UsageUpdate` event as `stream` is polled, and stays zeroed until one arrives.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
530
531impl Default for LanguageModelTextStream {
532 fn default() -> Self {
533 Self {
534 message_id: None,
535 stream: Box::pin(futures::stream::empty()),
536 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
537 }
538 }
539}
540
541pub trait LanguageModel: Send + Sync {
542 fn id(&self) -> LanguageModelId;
543 fn name(&self) -> LanguageModelName;
544 fn provider_id(&self) -> LanguageModelProviderId;
545 fn provider_name(&self) -> LanguageModelProviderName;
546 fn upstream_provider_id(&self) -> LanguageModelProviderId {
547 self.provider_id()
548 }
549 fn upstream_provider_name(&self) -> LanguageModelProviderName {
550 self.provider_name()
551 }
552
553 fn telemetry_id(&self) -> String;
554
555 fn api_key(&self, _cx: &App) -> Option<String> {
556 None
557 }
558
559 /// Whether this model supports images
560 fn supports_images(&self) -> bool;
561
562 /// Whether this model supports tools.
563 fn supports_tools(&self) -> bool;
564
565 /// Whether this model supports choosing which tool to use.
566 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
567
568 /// Returns whether this model supports "burn mode";
569 fn supports_burn_mode(&self) -> bool {
570 false
571 }
572
573 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
574 LanguageModelToolSchemaFormat::JsonSchema
575 }
576
577 fn max_token_count(&self) -> u64;
578 /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
579 fn max_token_count_in_burn_mode(&self) -> Option<u64> {
580 None
581 }
582 fn max_output_tokens(&self) -> Option<u64> {
583 None
584 }
585
586 fn count_tokens(
587 &self,
588 request: LanguageModelRequest,
589 cx: &App,
590 ) -> BoxFuture<'static, Result<u64>>;
591
592 fn stream_completion(
593 &self,
594 request: LanguageModelRequest,
595 cx: &AsyncApp,
596 ) -> BoxFuture<
597 'static,
598 Result<
599 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
600 LanguageModelCompletionError,
601 >,
602 >;
603
604 fn stream_completion_text(
605 &self,
606 request: LanguageModelRequest,
607 cx: &AsyncApp,
608 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
609 let future = self.stream_completion(request, cx);
610
611 async move {
612 let events = future.await?;
613 let mut events = events.fuse();
614 let mut message_id = None;
615 let mut first_item_text = None;
616 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
617
618 if let Some(first_event) = events.next().await {
619 match first_event {
620 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
621 message_id = Some(id);
622 }
623 Ok(LanguageModelCompletionEvent::Text(text)) => {
624 first_item_text = Some(text);
625 }
626 _ => (),
627 }
628 }
629
630 let stream = futures::stream::iter(first_item_text.map(Ok))
631 .chain(events.filter_map({
632 let last_token_usage = last_token_usage.clone();
633 move |result| {
634 let last_token_usage = last_token_usage.clone();
635 async move {
636 match result {
637 Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
638 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
639 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
640 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
641 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
642 Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
643 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
644 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
645 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
646 ..
647 }) => None,
648 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
649 *last_token_usage.lock() = token_usage;
650 None
651 }
652 Err(err) => Some(Err(err)),
653 }
654 }
655 }
656 }))
657 .boxed();
658
659 Ok(LanguageModelTextStream {
660 message_id,
661 stream,
662 last_token_usage,
663 })
664 }
665 .boxed()
666 }
667
668 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
669 None
670 }
671
672 #[cfg(any(test, feature = "test-support"))]
673 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
674 unimplemented!()
675 }
676}
677
678pub trait LanguageModelExt: LanguageModel {
679 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
680 match mode {
681 CompletionMode::Normal => self.max_token_count(),
682 CompletionMode::Max => self
683 .max_token_count_in_burn_mode()
684 .unwrap_or_else(|| self.max_token_count()),
685 }
686 }
687}
688impl LanguageModelExt for dyn LanguageModel {}
689
/// An error that occurred when trying to authenticate the language model provider.
///
/// Returned by [`LanguageModelProvider::authenticate`].
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The connection to the provider was refused.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure, displayed exactly as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
700
/// Supplies a set of language models, along with authentication and a configuration view
/// for them.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Display name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon for this provider; defaults to [`IconName::ZedAssistant`].
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The model to use by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The default "fast" model (presumably for latency-sensitive tasks), if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models offered by this provider.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models this provider recommends; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether this provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authentication; the task resolves to an [`AuthenticateError`] on failure.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view for configuring this provider on behalf of `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Resets this provider's credentials.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
723
/// The agent a provider's configuration view is built for (see
/// [`LanguageModelProvider::configuration_view`]). Defaults to `ZedAgent`.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    /// Another agent, identified by a string (presumably its display name).
    Other(SharedString),
}
730
/// UI location in which a provider's notice is displayed. "Tos" presumably stands for terms
/// of service; confirm against callers.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    // NOTE(review): presumably the text-thread popup; confirm against callers.
    TextThreadPopup,
    // NOTE(review): presumably the provider configuration screen; confirm against callers.
    Configuration,
}
740
741pub trait LanguageModelProviderState: 'static {
742 type ObservableEntity;
743
744 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
745
746 fn subscribe<T: 'static>(
747 &self,
748 cx: &mut gpui::Context<T>,
749 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
750 ) -> Option<gpui::Subscription> {
751 let entity = self.observable_entity()?;
752 Some(cx.observe(&entity, move |this, _, cx| {
753 callback(this, cx);
754 }))
755 }
756}
757
/// Identifier of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier of a language model provider, e.g. [`ANTHROPIC_PROVIDER_ID`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Display name of a language model provider, e.g. [`ANTHROPIC_PROVIDER_NAME`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
769
770impl LanguageModelProviderId {
771 pub const fn new(id: &'static str) -> Self {
772 Self(SharedString::new_static(id))
773 }
774}
775
776impl LanguageModelProviderName {
777 pub const fn new(id: &'static str) -> Self {
778 Self(SharedString::new_static(id))
779 }
780}
781
782impl fmt::Display for LanguageModelProviderId {
783 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
784 write!(f, "{}", self.0)
785 }
786}
787
788impl fmt::Display for LanguageModelProviderName {
789 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
790 write!(f, "{}", self.0)
791 }
792}
793
794impl From<String> for LanguageModelId {
795 fn from(value: String) -> Self {
796 Self(SharedString::from(value))
797 }
798}
799
800impl From<String> for LanguageModelName {
801 fn from(value: String) -> Self {
802 Self(SharedString::from(value))
803 }
804}
805
806impl From<String> for LanguageModelProviderId {
807 fn from(value: String) -> Self {
808 Self(SharedString::from(value))
809 }
810}
811
812impl From<String> for LanguageModelProviderName {
813 fn from(value: String) -> Self {
814 Self(SharedString::from(value))
815 }
816}
817
818impl From<Arc<str>> for LanguageModelProviderId {
819 fn from(value: Arc<str>) -> Self {
820 Self(SharedString::from(value))
821 }
822}
823
824impl From<Arc<str>> for LanguageModelProviderName {
825 fn from(value: Arc<str>) -> Self {
826 Self(SharedString::from(value))
827 }
828}
829
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    /// `http_<status>` failures come from Zed's own service, so they must be attributed to
    /// Zed rather than to the upstream provider, and `retry_after` must be preserved.
    #[test]
    fn test_from_cloud_failure_with_zed_http_status() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "http_429".to_string(),
            "Too many requests".to_string(),
            Some(Duration::from_secs(3)),
        );

        match error {
            LanguageModelCompletionError::RateLimitExceeded {
                provider,
                retry_after,
            } => {
                assert_eq!(provider, ZED_CLOUD_PROVIDER_NAME);
                assert_eq!(retry_after, Some(Duration::from_secs(3)));
            }
            _ => panic!("Expected RateLimitExceeded for http_429, got: {:?}", error),
        }
    }

    /// The non-standard 529 status is treated the same as 503.
    #[test]
    fn test_from_cloud_failure_maps_529_to_server_overloaded() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_529".to_string(),
            "Overloaded".to_string(),
            Some(Duration::from_secs(5)),
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded {
                provider,
                retry_after,
            } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(retry_after, Some(Duration::from_secs(5)));
            }
            _ => panic!(
                "Expected ServerOverloaded for upstream_http_529, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_unrecognized_code_falls_back_to_other() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "mystery_failure".to_string(),
            "something went wrong".to_string(),
            None,
        );

        assert!(
            matches!(error, LanguageModelCompletionError::Other(_)),
            "got: {:?}",
            error
        );
        assert_eq!(
            error.to_string(),
            "completion request failed, code: mystery_failure, message: something went wrong"
        );
    }

    /// An `upstream_http_error` whose payload is not JSON, or lacks `upstream_status`, cannot
    /// be classified and must fall back to `Other`.
    #[test]
    fn test_upstream_http_error_without_valid_payload_falls_back_to_other() {
        for message in ["not json", r#"{"message":"no status"}"#] {
            let error = LanguageModelCompletionError::from_cloud_failure(
                String::from("anthropic").into(),
                "upstream_http_error".to_string(),
                message.to_string(),
                None,
            );

            assert!(
                matches!(error, LanguageModelCompletionError::Other(_)),
                "message {:?} produced: {:?}",
                message,
                error
            );
        }
    }

    #[test]
    fn test_from_http_status_unmapped_status_keeps_details() {
        let error = LanguageModelCompletionError::from_http_status(
            OPEN_AI_PROVIDER_NAME,
            StatusCode::BAD_GATEWAY,
            "bad gateway".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::HttpResponseError {
                provider,
                status_code,
                message,
            } => {
                assert_eq!(provider, OPEN_AI_PROVIDER_NAME);
                assert_eq!(status_code, StatusCode::BAD_GATEWAY);
                assert_eq!(message, "bad gateway");
            }
            _ => panic!("Expected HttpResponseError for 502, got: {:?}", error),
        }
    }

    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 10,
            output_tokens: 20,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 4,
        };
        let b = TokenUsage {
            input_tokens: 1,
            output_tokens: 2,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 4,
        };

        assert_eq!(a.total_tokens(), 37);
        assert_eq!((a + b) - b, a);
        assert_eq!((a - b).total_tokens(), 27);
    }

    #[test]
    fn test_token_usage_serialization_skips_default_fields() {
        let usage = TokenUsage {
            input_tokens: 5,
            ..Default::default()
        };

        let serialized = serde_json::to_value(usage).unwrap();
        assert_eq!(serialized, serde_json::json!({ "input_tokens": 5 }));

        let deserialized: TokenUsage = serde_json::from_value(serialized).unwrap();
        assert_eq!(deserialized, usage);
    }

    #[test]
    fn test_stop_reason_serializes_as_snake_case() {
        assert_eq!(
            serde_json::to_value(StopReason::MaxTokens).unwrap(),
            serde_json::json!("max_tokens")
        );
        assert_eq!(
            serde_json::from_value::<StopReason>(serde_json::json!("end_turn")).unwrap(),
            StopReason::EndTurn
        );
    }

    #[test]
    fn test_language_model_tool_use_id_display() {
        assert_eq!(
            LanguageModelToolUseId::from("toolu_123").to_string(),
            "toolu_123"
        );
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}