test.rs

  1use std::{
  2    sync::atomic::{self, AtomicUsize, Ordering},
  3    time::Instant,
  4};
  5
  6use async_trait::async_trait;
  7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  8use gpui::AppContext;
  9use parking_lot::Mutex;
 10
 11use crate::{
 12    auth::{CredentialProvider, ProviderCredential},
 13    completion::{CompletionProvider, CompletionRequest},
 14    embedding::{Embedding, EmbeddingProvider},
 15    models::{LanguageModel, TruncationDirection},
 16};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModel {
 20    pub capacity: usize,
 21}
 22
 23impl LanguageModel for FakeLanguageModel {
 24    fn name(&self) -> String {
 25        "dummy".to_string()
 26    }
 27    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 28        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 29    }
 30    fn truncate(
 31        &self,
 32        content: &str,
 33        length: usize,
 34        direction: TruncationDirection,
 35    ) -> anyhow::Result<String> {
 36        println!("TRYING TO TRUNCATE: {:?}", length.clone());
 37
 38        if length > self.count_tokens(content)? {
 39            println!("NOT TRUNCATING");
 40            return anyhow::Ok(content.to_string());
 41        }
 42
 43        anyhow::Ok(match direction {
 44            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 45                .into_iter()
 46                .collect::<String>(),
 47            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 48                .into_iter()
 49                .collect::<String>(),
 50        })
 51    }
 52    fn capacity(&self) -> anyhow::Result<usize> {
 53        anyhow::Ok(self.capacity)
 54    }
 55}
 56
 57pub struct FakeEmbeddingProvider {
 58    pub embedding_count: AtomicUsize,
 59}
 60
 61impl Clone for FakeEmbeddingProvider {
 62    fn clone(&self) -> Self {
 63        FakeEmbeddingProvider {
 64            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 65        }
 66    }
 67}
 68
 69impl Default for FakeEmbeddingProvider {
 70    fn default() -> Self {
 71        FakeEmbeddingProvider {
 72            embedding_count: AtomicUsize::default(),
 73        }
 74    }
 75}
 76
 77impl FakeEmbeddingProvider {
 78    pub fn embedding_count(&self) -> usize {
 79        self.embedding_count.load(atomic::Ordering::SeqCst)
 80    }
 81
 82    pub fn embed_sync(&self, span: &str) -> Embedding {
 83        let mut result = vec![1.0; 26];
 84        for letter in span.chars() {
 85            let letter = letter.to_ascii_lowercase();
 86            if letter as u32 >= 'a' as u32 {
 87                let ix = (letter as u32) - ('a' as u32);
 88                if ix < 26 {
 89                    result[ix as usize] += 1.0;
 90                }
 91            }
 92        }
 93
 94        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 95        for x in &mut result {
 96            *x /= norm;
 97        }
 98
 99        result.into()
100    }
101}
102
103impl CredentialProvider for FakeEmbeddingProvider {
104    fn has_credentials(&self) -> bool {
105        true
106    }
107
108    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
109        async { ProviderCredential::NotNeeded }.boxed()
110    }
111
112    fn save_credentials(
113        &self,
114        _cx: &mut AppContext,
115        _credential: ProviderCredential,
116    ) -> BoxFuture<()> {
117        async {}.boxed()
118    }
119
120    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
121        async {}.boxed()
122    }
123}
124
125#[async_trait]
126impl EmbeddingProvider for FakeEmbeddingProvider {
127    fn base_model(&self) -> Box<dyn LanguageModel> {
128        Box::new(FakeLanguageModel { capacity: 1000 })
129    }
130    fn max_tokens_per_batch(&self) -> usize {
131        1000
132    }
133
134    fn rate_limit_expiration(&self) -> Option<Instant> {
135        None
136    }
137
138    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
139        self.embedding_count
140            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
141
142        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
143    }
144}
145
146pub struct FakeCompletionProvider {
147    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
148}
149
150impl Clone for FakeCompletionProvider {
151    fn clone(&self) -> Self {
152        Self {
153            last_completion_tx: Mutex::new(None),
154        }
155    }
156}
157
158impl FakeCompletionProvider {
159    pub fn new() -> Self {
160        Self {
161            last_completion_tx: Mutex::new(None),
162        }
163    }
164
165    pub fn send_completion(&self, completion: impl Into<String>) {
166        let mut tx = self.last_completion_tx.lock();
167        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
168    }
169
170    pub fn finish_completion(&self) {
171        self.last_completion_tx.lock().take().unwrap();
172    }
173}
174
175impl CredentialProvider for FakeCompletionProvider {
176    fn has_credentials(&self) -> bool {
177        true
178    }
179
180    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
181        async { ProviderCredential::NotNeeded }.boxed()
182    }
183
184    fn save_credentials(
185        &self,
186        _cx: &mut AppContext,
187        _credential: ProviderCredential,
188    ) -> BoxFuture<()> {
189        async {}.boxed()
190    }
191
192    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
193        async {}.boxed()
194    }
195}
196
197impl CompletionProvider for FakeCompletionProvider {
198    fn base_model(&self) -> Box<dyn LanguageModel> {
199        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
200        model
201    }
202    fn complete(
203        &self,
204        _prompt: Box<dyn CompletionRequest>,
205    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
206        let (tx, rx) = mpsc::channel(1);
207        *self.last_completion_tx.lock() = Some(tx);
208        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
209    }
210    fn box_clone(&self) -> Box<dyn CompletionProvider> {
211        Box::new((*self).clone())
212    }
213}