1use std::{
2 sync::atomic::{self, AtomicUsize, Ordering},
3 time::Instant,
4};
5
6use async_trait::async_trait;
7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
8use gpui::AppContext;
9use parking_lot::Mutex;
10
11use crate::{
12 auth::{CredentialProvider, ProviderCredential},
13 completion::{CompletionProvider, CompletionRequest},
14 embedding::{Embedding, EmbeddingProvider},
15 models::{LanguageModel, TruncationDirection},
16};
17
/// Fake [`LanguageModel`] that treats every `char` of the input as one token.
///
/// The model's capacity is whatever the caller puts in `capacity`.
#[derive(Clone)]
pub struct FakeLanguageModel {
    /// Value reported back from `LanguageModel::capacity`.
    pub capacity: usize,
}
22
23impl LanguageModel for FakeLanguageModel {
24 fn name(&self) -> String {
25 "dummy".to_string()
26 }
27 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
28 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
29 }
30 fn truncate(
31 &self,
32 content: &str,
33 length: usize,
34 direction: TruncationDirection,
35 ) -> anyhow::Result<String> {
36 println!("TRYING TO TRUNCATE: {:?}", length.clone());
37
38 if length > self.count_tokens(content)? {
39 println!("NOT TRUNCATING");
40 return anyhow::Ok(content.to_string());
41 }
42
43 anyhow::Ok(match direction {
44 TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
45 .into_iter()
46 .collect::<String>(),
47 TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
48 .into_iter()
49 .collect::<String>(),
50 })
51 }
52 fn capacity(&self) -> anyhow::Result<usize> {
53 anyhow::Ok(self.capacity)
54 }
55}
56
/// Fake embedding provider whose embeddings are built from letter counts.
///
/// The default value starts with `embedding_count` at zero.
#[derive(Default)]
pub struct FakeEmbeddingProvider {
    /// Running total of spans embedded; `embed_batch` adds each batch's size.
    pub embedding_count: AtomicUsize,
}
61
62impl Clone for FakeEmbeddingProvider {
63 fn clone(&self) -> Self {
64 FakeEmbeddingProvider {
65 embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
66 }
67 }
68}
69
70impl FakeEmbeddingProvider {
71 pub fn embedding_count(&self) -> usize {
72 self.embedding_count.load(atomic::Ordering::SeqCst)
73 }
74
75 pub fn embed_sync(&self, span: &str) -> Embedding {
76 let mut result = vec![1.0; 26];
77 for letter in span.chars() {
78 let letter = letter.to_ascii_lowercase();
79 if letter as u32 >= 'a' as u32 {
80 let ix = (letter as u32) - ('a' as u32);
81 if ix < 26 {
82 result[ix as usize] += 1.0;
83 }
84 }
85 }
86
87 let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
88 for x in &mut result {
89 *x /= norm;
90 }
91
92 result.into()
93 }
94}
95
96impl CredentialProvider for FakeEmbeddingProvider {
97 fn has_credentials(&self) -> bool {
98 true
99 }
100
101 fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
102 async { ProviderCredential::NotNeeded }.boxed()
103 }
104
105 fn save_credentials(
106 &self,
107 _cx: &mut AppContext,
108 _credential: ProviderCredential,
109 ) -> BoxFuture<()> {
110 async {}.boxed()
111 }
112
113 fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
114 async {}.boxed()
115 }
116}
117
118#[async_trait]
119impl EmbeddingProvider for FakeEmbeddingProvider {
120 fn base_model(&self) -> Box<dyn LanguageModel> {
121 Box::new(FakeLanguageModel { capacity: 1000 })
122 }
123 fn max_tokens_per_batch(&self) -> usize {
124 1000
125 }
126
127 fn rate_limit_expiration(&self) -> Option<Instant> {
128 None
129 }
130
131 async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
132 self.embedding_count
133 .fetch_add(spans.len(), atomic::Ordering::SeqCst);
134
135 anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
136 }
137}
138
139pub struct FakeCompletionProvider {
140 last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
141}
142
143impl Clone for FakeCompletionProvider {
144 fn clone(&self) -> Self {
145 Self {
146 last_completion_tx: Mutex::new(None),
147 }
148 }
149}
150
151impl FakeCompletionProvider {
152 pub fn new() -> Self {
153 Self {
154 last_completion_tx: Mutex::new(None),
155 }
156 }
157
158 pub fn send_completion(&self, completion: impl Into<String>) {
159 let mut tx = self.last_completion_tx.lock();
160 tx.as_mut().unwrap().try_send(completion.into()).unwrap();
161 }
162
163 pub fn finish_completion(&self) {
164 self.last_completion_tx.lock().take().unwrap();
165 }
166}
167
168impl CredentialProvider for FakeCompletionProvider {
169 fn has_credentials(&self) -> bool {
170 true
171 }
172
173 fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
174 async { ProviderCredential::NotNeeded }.boxed()
175 }
176
177 fn save_credentials(
178 &self,
179 _cx: &mut AppContext,
180 _credential: ProviderCredential,
181 ) -> BoxFuture<()> {
182 async {}.boxed()
183 }
184
185 fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
186 async {}.boxed()
187 }
188}
189
190impl CompletionProvider for FakeCompletionProvider {
191 fn base_model(&self) -> Box<dyn LanguageModel> {
192 let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
193 model
194 }
195 fn complete(
196 &self,
197 _prompt: Box<dyn CompletionRequest>,
198 ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
199 let (tx, rx) = mpsc::channel(1);
200 *self.last_completion_tx.lock() = Some(tx);
201 async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
202 }
203 fn box_clone(&self) -> Box<dyn CompletionProvider> {
204 Box::new((*self).clone())
205 }
206}