1use std::{
2 sync::atomic::{self, AtomicUsize, Ordering},
3 time::Instant,
4};
5
6use async_trait::async_trait;
7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
8use parking_lot::Mutex;
9
10use crate::{
11 auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
12 completion::{CompletionProvider, CompletionRequest},
13 embedding::{Embedding, EmbeddingProvider},
14 models::{LanguageModel, TruncationDirection},
15};
16
/// Deterministic stand-in for a real language model, for use in tests.
///
/// Its `LanguageModel` implementation treats every `char` as one token and
/// reports `capacity` as the model's context size.
#[derive(Clone)]
pub struct FakeLanguageModel {
    /// Value returned from `LanguageModel::capacity`.
    pub capacity: usize,
}
21
22impl LanguageModel for FakeLanguageModel {
23 fn name(&self) -> String {
24 "dummy".to_string()
25 }
26 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
27 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
28 }
29 fn truncate(
30 &self,
31 content: &str,
32 length: usize,
33 direction: TruncationDirection,
34 ) -> anyhow::Result<String> {
35 if length > self.count_tokens(content)? {
36 return anyhow::Ok(content.to_string());
37 }
38
39 anyhow::Ok(match direction {
40 TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
41 .into_iter()
42 .collect::<String>(),
43 TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
44 .into_iter()
45 .collect::<String>(),
46 })
47 }
48 fn capacity(&self) -> anyhow::Result<usize> {
49 anyhow::Ok(self.capacity)
50 }
51}
52
53pub struct FakeEmbeddingProvider {
54 pub embedding_count: AtomicUsize,
55 pub credential_provider: NullCredentialProvider,
56}
57
58impl Clone for FakeEmbeddingProvider {
59 fn clone(&self) -> Self {
60 FakeEmbeddingProvider {
61 embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
62 credential_provider: self.credential_provider.clone(),
63 }
64 }
65}
66
67impl Default for FakeEmbeddingProvider {
68 fn default() -> Self {
69 FakeEmbeddingProvider {
70 embedding_count: AtomicUsize::default(),
71 credential_provider: NullCredentialProvider {},
72 }
73 }
74}
75
76impl FakeEmbeddingProvider {
77 pub fn embedding_count(&self) -> usize {
78 self.embedding_count.load(atomic::Ordering::SeqCst)
79 }
80
81 pub fn embed_sync(&self, span: &str) -> Embedding {
82 let mut result = vec![1.0; 26];
83 for letter in span.chars() {
84 let letter = letter.to_ascii_lowercase();
85 if letter as u32 >= 'a' as u32 {
86 let ix = (letter as u32) - ('a' as u32);
87 if ix < 26 {
88 result[ix as usize] += 1.0;
89 }
90 }
91 }
92
93 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
94 for x in &mut result {
95 *x /= norm;
96 }
97
98 result.into()
99 }
100}
101
102#[async_trait]
103impl EmbeddingProvider for FakeEmbeddingProvider {
104 fn base_model(&self) -> Box<dyn LanguageModel> {
105 Box::new(FakeLanguageModel { capacity: 1000 })
106 }
107 fn credential_provider(&self) -> Box<dyn CredentialProvider> {
108 let credential_provider: Box<dyn CredentialProvider> =
109 Box::new(self.credential_provider.clone());
110 credential_provider
111 }
112 fn max_tokens_per_batch(&self) -> usize {
113 1000
114 }
115
116 fn rate_limit_expiration(&self) -> Option<Instant> {
117 None
118 }
119
120 async fn embed_batch(
121 &self,
122 spans: Vec<String>,
123 _credential: ProviderCredential,
124 ) -> anyhow::Result<Vec<Embedding>> {
125 self.embedding_count
126 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
127
128 anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
129 }
130}
131
132pub struct TestCompletionProvider {
133 last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
134}
135
136impl TestCompletionProvider {
137 pub fn new() -> Self {
138 Self {
139 last_completion_tx: Mutex::new(None),
140 }
141 }
142
143 pub fn send_completion(&self, completion: impl Into<String>) {
144 let mut tx = self.last_completion_tx.lock();
145 tx.as_mut().unwrap().try_send(completion.into()).unwrap();
146 }
147
148 pub fn finish_completion(&self) {
149 self.last_completion_tx.lock().take().unwrap();
150 }
151}
152
153impl CompletionProvider for TestCompletionProvider {
154 fn base_model(&self) -> Box<dyn LanguageModel> {
155 let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
156 model
157 }
158 fn complete(
159 &self,
160 _prompt: Box<dyn CompletionRequest>,
161 ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
162 let (tx, rx) = mpsc::channel(1);
163 *self.last_completion_tx.lock() = Some(tx);
164 async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
165 }
166}