1use std::{
2 sync::atomic::{self, AtomicUsize, Ordering},
3 time::Instant,
4};
5
6use async_trait::async_trait;
7
8use crate::{
9 auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
10 embedding::{Embedding, EmbeddingProvider},
11 models::{LanguageModel, TruncationDirection},
12};
13
/// A stand-in language model whose only state is a configurable capacity.
#[derive(Clone)]
pub struct FakeLanguageModel {
    /// Value reported by this model's `capacity()` implementation.
    pub capacity: usize,
}
18
19impl LanguageModel for FakeLanguageModel {
20 fn name(&self) -> String {
21 "dummy".to_string()
22 }
23 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
24 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
25 }
26 fn truncate(
27 &self,
28 content: &str,
29 length: usize,
30 direction: TruncationDirection,
31 ) -> anyhow::Result<String> {
32 if length > self.count_tokens(content)? {
33 return anyhow::Ok(content.to_string());
34 }
35
36 anyhow::Ok(match direction {
37 TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
38 .into_iter()
39 .collect::<String>(),
40 TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
41 .into_iter()
42 .collect::<String>(),
43 })
44 }
45 fn capacity(&self) -> anyhow::Result<usize> {
46 anyhow::Ok(self.capacity)
47 }
48}
49
50pub struct FakeEmbeddingProvider {
51 pub embedding_count: AtomicUsize,
52 pub credential_provider: NullCredentialProvider,
53}
54
55impl Clone for FakeEmbeddingProvider {
56 fn clone(&self) -> Self {
57 FakeEmbeddingProvider {
58 embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
59 credential_provider: self.credential_provider.clone(),
60 }
61 }
62}
63
64impl Default for FakeEmbeddingProvider {
65 fn default() -> Self {
66 FakeEmbeddingProvider {
67 embedding_count: AtomicUsize::default(),
68 credential_provider: NullCredentialProvider {},
69 }
70 }
71}
72
73impl FakeEmbeddingProvider {
74 pub fn embedding_count(&self) -> usize {
75 self.embedding_count.load(atomic::Ordering::SeqCst)
76 }
77
78 pub fn embed_sync(&self, span: &str) -> Embedding {
79 let mut result = vec![1.0; 26];
80 for letter in span.chars() {
81 let letter = letter.to_ascii_lowercase();
82 if letter as u32 >= 'a' as u32 {
83 let ix = (letter as u32) - ('a' as u32);
84 if ix < 26 {
85 result[ix as usize] += 1.0;
86 }
87 }
88 }
89
90 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
91 for x in &mut result {
92 *x /= norm;
93 }
94
95 result.into()
96 }
97}
98
99#[async_trait]
100impl EmbeddingProvider for FakeEmbeddingProvider {
101 fn base_model(&self) -> Box<dyn LanguageModel> {
102 Box::new(FakeLanguageModel { capacity: 1000 })
103 }
104 fn credential_provider(&self) -> Box<dyn CredentialProvider> {
105 let credential_provider: Box<dyn CredentialProvider> =
106 Box::new(self.credential_provider.clone());
107 credential_provider
108 }
109 fn max_tokens_per_batch(&self) -> usize {
110 1000
111 }
112
113 fn rate_limit_expiration(&self) -> Option<Instant> {
114 None
115 }
116
117 async fn embed_batch(
118 &self,
119 spans: Vec<String>,
120 _credential: ProviderCredential,
121 ) -> anyhow::Result<Vec<Embedding>> {
122 self.embedding_count
123 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
124
125 anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
126 }
127}