use std::{
    sync::{
        atomic::{self, AtomicUsize, Ordering},
        Arc,
    },
    time::Instant,
};

use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    embedding::{Embedding, EmbeddingProvider},
    models::{LanguageModel, TruncationDirection},
};
17
/// Language-model stand-in for tests.
///
/// It holds a single configurable token budget, `capacity`.
#[derive(Clone)]
pub struct FakeLanguageModel {
    /// Token budget this fake model reports as its capacity.
    pub capacity: usize,
}
22
23impl LanguageModel for FakeLanguageModel {
24 fn name(&self) -> String {
25 "dummy".to_string()
26 }
27 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
28 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
29 }
30 fn truncate(
31 &self,
32 content: &str,
33 length: usize,
34 direction: TruncationDirection,
35 ) -> anyhow::Result<String> {
36 println!("TRYING TO TRUNCATE: {:?}", length.clone());
37
38 if length > self.count_tokens(content)? {
39 println!("NOT TRUNCATING");
40 return anyhow::Ok(content.to_string());
41 }
42
43 anyhow::Ok(match direction {
44 TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
45 .into_iter()
46 .collect::<String>(),
47 TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
48 .into_iter()
49 .collect::<String>(),
50 })
51 }
52 fn capacity(&self) -> anyhow::Result<usize> {
53 anyhow::Ok(self.capacity)
54 }
55}
56
/// Deterministic, offline embedding provider for tests.
///
/// `embedding_count` starts at zero (see the derived `Default`).
#[derive(Default)]
pub struct FakeEmbeddingProvider {
    /// Running count of embedded spans.
    pub embedding_count: AtomicUsize,
}

/// Cloning snapshots the current counter value into a new, independent
/// `AtomicUsize`: the clone and the original do NOT share a counter afterwards.
impl Clone for FakeEmbeddingProvider {
    fn clone(&self) -> Self {
        Self {
            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
        }
    }
}
76
77impl FakeEmbeddingProvider {
78 pub fn embedding_count(&self) -> usize {
79 self.embedding_count.load(atomic::Ordering::SeqCst)
80 }
81
82 pub fn embed_sync(&self, span: &str) -> Embedding {
83 let mut result = vec![1.0; 26];
84 for letter in span.chars() {
85 let letter = letter.to_ascii_lowercase();
86 if letter as u32 >= 'a' as u32 {
87 let ix = (letter as u32) - ('a' as u32);
88 if ix < 26 {
89 result[ix as usize] += 1.0;
90 }
91 }
92 }
93
94 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
95 for x in &mut result {
96 *x /= norm;
97 }
98
99 result.into()
100 }
101}
102
103impl CredentialProvider for FakeEmbeddingProvider {
104 fn has_credentials(&self) -> bool {
105 true
106 }
107
108 fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
109 async { ProviderCredential::NotNeeded }.boxed()
110 }
111
112 fn save_credentials(
113 &self,
114 _cx: &mut AppContext,
115 _credential: ProviderCredential,
116 ) -> BoxFuture<()> {
117 async {}.boxed()
118 }
119
120 fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
121 async {}.boxed()
122 }
123}
124
125#[async_trait]
126impl EmbeddingProvider for FakeEmbeddingProvider {
127 fn base_model(&self) -> Box<dyn LanguageModel> {
128 Box::new(FakeLanguageModel { capacity: 1000 })
129 }
130 fn max_tokens_per_batch(&self) -> usize {
131 1000
132 }
133
134 fn rate_limit_expiration(&self) -> Option<Instant> {
135 None
136 }
137
138 async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
139 self.embedding_count
140 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
141
142 anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
143 }
144}
145
146pub struct FakeCompletionProvider {
147 last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
148}
149
150impl Clone for FakeCompletionProvider {
151 fn clone(&self) -> Self {
152 Self {
153 last_completion_tx: Mutex::new(None),
154 }
155 }
156}
157
158impl FakeCompletionProvider {
159 pub fn new() -> Self {
160 Self {
161 last_completion_tx: Mutex::new(None),
162 }
163 }
164
165 pub fn send_completion(&self, completion: impl Into<String>) {
166 let mut tx = self.last_completion_tx.lock();
167 tx.as_mut().unwrap().try_send(completion.into()).unwrap();
168 }
169
170 pub fn finish_completion(&self) {
171 self.last_completion_tx.lock().take().unwrap();
172 }
173}
174
175impl CredentialProvider for FakeCompletionProvider {
176 fn has_credentials(&self) -> bool {
177 true
178 }
179
180 fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
181 async { ProviderCredential::NotNeeded }.boxed()
182 }
183
184 fn save_credentials(
185 &self,
186 _cx: &mut AppContext,
187 _credential: ProviderCredential,
188 ) -> BoxFuture<()> {
189 async {}.boxed()
190 }
191
192 fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
193 async {}.boxed()
194 }
195}
196
197impl CompletionProvider for FakeCompletionProvider {
198 fn base_model(&self) -> Box<dyn LanguageModel> {
199 let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
200 model
201 }
202 fn complete(
203 &self,
204 _prompt: Box<dyn CompletionRequest>,
205 ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
206 let (tx, rx) = mpsc::channel(1);
207 *self.last_completion_tx.lock() = Some(tx);
208 async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
209 }
210 fn box_clone(&self) -> Box<dyn CompletionProvider> {
211 Box::new((*self).clone())
212 }
213}