1use std::{
2 sync::atomic::{self, AtomicUsize, Ordering},
3 time::Instant,
4};
5
6use async_trait::async_trait;
7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
8use gpui::AppContext;
9use parking_lot::Mutex;
10
11use crate::{
12 auth::{CredentialProvider, ProviderCredential},
13 completion::{CompletionProvider, CompletionRequest},
14 embedding::{Embedding, EmbeddingProvider},
15 models::{LanguageModel, TruncationDirection},
16};
17
18#[derive(Clone)]
19pub struct FakeLanguageModel {
20 pub capacity: usize,
21}
22
23impl LanguageModel for FakeLanguageModel {
24 fn name(&self) -> String {
25 "dummy".to_string()
26 }
27 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
28 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
29 }
30 fn truncate(
31 &self,
32 content: &str,
33 length: usize,
34 direction: TruncationDirection,
35 ) -> anyhow::Result<String> {
36 println!("TRYING TO TRUNCATE: {:?}", length.clone());
37
38 if length > self.count_tokens(content)? {
39 println!("NOT TRUNCATING");
40 return anyhow::Ok(content.to_string());
41 }
42
43 anyhow::Ok(match direction {
44 TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
45 .into_iter()
46 .collect::<String>(),
47 TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
48 .into_iter()
49 .collect::<String>(),
50 })
51 }
52 fn capacity(&self) -> anyhow::Result<usize> {
53 anyhow::Ok(self.capacity)
54 }
55}
56
/// Embedding provider double that keeps a running count in `embedding_count`.
///
/// `Default` yields a provider whose counter starts at zero.
#[derive(Default)]
pub struct FakeEmbeddingProvider {
    /// Atomic counter; public so callers can read or adjust it directly.
    pub embedding_count: AtomicUsize,
}

// `AtomicUsize` does not implement `Clone`, so this impl cannot be derived.
impl Clone for FakeEmbeddingProvider {
    /// Produces an independent provider whose counter starts at the value the
    /// original holds at the moment of cloning; later updates are not shared.
    fn clone(&self) -> Self {
        let snapshot = self.embedding_count.load(Ordering::SeqCst);
        Self {
            embedding_count: AtomicUsize::new(snapshot),
        }
    }
}
76
77impl FakeEmbeddingProvider {
78 pub fn embedding_count(&self) -> usize {
79 self.embedding_count.load(atomic::Ordering::SeqCst)
80 }
81
82 pub fn embed_sync(&self, span: &str) -> Embedding {
83 let mut result = vec![1.0; 26];
84 for letter in span.chars() {
85 let letter = letter.to_ascii_lowercase();
86 if letter as u32 >= 'a' as u32 {
87 let ix = (letter as u32) - ('a' as u32);
88 if ix < 26 {
89 result[ix as usize] += 1.0;
90 }
91 }
92 }
93
94 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
95 for x in &mut result {
96 *x /= norm;
97 }
98
99 result.into()
100 }
101}
102
103impl CredentialProvider for FakeEmbeddingProvider {
104 fn has_credentials(&self) -> bool {
105 true
106 }
107 fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
108 ProviderCredential::NotNeeded
109 }
110 fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
111 fn delete_credentials(&self, _cx: &mut AppContext) {}
112}
113
114#[async_trait]
115impl EmbeddingProvider for FakeEmbeddingProvider {
116 fn base_model(&self) -> Box<dyn LanguageModel> {
117 Box::new(FakeLanguageModel { capacity: 1000 })
118 }
119 fn max_tokens_per_batch(&self) -> usize {
120 1000
121 }
122
123 fn rate_limit_expiration(&self) -> Option<Instant> {
124 None
125 }
126
127 async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
128 self.embedding_count
129 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
130
131 anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
132 }
133}
134
135pub struct FakeCompletionProvider {
136 last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
137}
138
139impl Clone for FakeCompletionProvider {
140 fn clone(&self) -> Self {
141 Self {
142 last_completion_tx: Mutex::new(None),
143 }
144 }
145}
146
147impl FakeCompletionProvider {
148 pub fn new() -> Self {
149 Self {
150 last_completion_tx: Mutex::new(None),
151 }
152 }
153
154 pub fn send_completion(&self, completion: impl Into<String>) {
155 let mut tx = self.last_completion_tx.lock();
156 tx.as_mut().unwrap().try_send(completion.into()).unwrap();
157 }
158
159 pub fn finish_completion(&self) {
160 self.last_completion_tx.lock().take().unwrap();
161 }
162}
163
164impl CredentialProvider for FakeCompletionProvider {
165 fn has_credentials(&self) -> bool {
166 true
167 }
168 fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
169 ProviderCredential::NotNeeded
170 }
171 fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
172 fn delete_credentials(&self, _cx: &mut AppContext) {}
173}
174
175impl CompletionProvider for FakeCompletionProvider {
176 fn base_model(&self) -> Box<dyn LanguageModel> {
177 let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
178 model
179 }
180 fn complete(
181 &self,
182 _prompt: Box<dyn CompletionRequest>,
183 ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
184 let (tx, rx) = mpsc::channel(1);
185 *self.last_completion_tx.lock() = Some(tx);
186 async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
187 }
188 fn box_clone(&self) -> Box<dyn CompletionProvider> {
189 Box::new((*self).clone())
190 }
191}