test.rs

  1use std::{
  2    sync::atomic::{self, AtomicUsize, Ordering},
  3    time::Instant,
  4};
  5
  6use async_trait::async_trait;
  7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  8use gpui::AppContext;
  9use parking_lot::Mutex;
 10
 11use crate::{
 12    auth::{CredentialProvider, ProviderCredential},
 13    completion::{CompletionProvider, CompletionRequest},
 14    embedding::{Embedding, EmbeddingProvider},
 15    models::{LanguageModel, TruncationDirection},
 16};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModel {
 20    pub capacity: usize,
 21}
 22
 23impl LanguageModel for FakeLanguageModel {
 24    fn name(&self) -> String {
 25        "dummy".to_string()
 26    }
 27    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 28        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 29    }
 30    fn truncate(
 31        &self,
 32        content: &str,
 33        length: usize,
 34        direction: TruncationDirection,
 35    ) -> anyhow::Result<String> {
 36        println!("TRYING TO TRUNCATE: {:?}", length.clone());
 37
 38        if length > self.count_tokens(content)? {
 39            println!("NOT TRUNCATING");
 40            return anyhow::Ok(content.to_string());
 41        }
 42
 43        anyhow::Ok(match direction {
 44            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 45                .into_iter()
 46                .collect::<String>(),
 47            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 48                .into_iter()
 49                .collect::<String>(),
 50        })
 51    }
 52    fn capacity(&self) -> anyhow::Result<usize> {
 53        anyhow::Ok(self.capacity)
 54    }
 55}
 56
 57pub struct FakeEmbeddingProvider {
 58    pub embedding_count: AtomicUsize,
 59}
 60
 61impl Clone for FakeEmbeddingProvider {
 62    fn clone(&self) -> Self {
 63        FakeEmbeddingProvider {
 64            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 65        }
 66    }
 67}
 68
 69impl Default for FakeEmbeddingProvider {
 70    fn default() -> Self {
 71        FakeEmbeddingProvider {
 72            embedding_count: AtomicUsize::default(),
 73        }
 74    }
 75}
 76
 77impl FakeEmbeddingProvider {
 78    pub fn embedding_count(&self) -> usize {
 79        self.embedding_count.load(atomic::Ordering::SeqCst)
 80    }
 81
 82    pub fn embed_sync(&self, span: &str) -> Embedding {
 83        let mut result = vec![1.0; 26];
 84        for letter in span.chars() {
 85            let letter = letter.to_ascii_lowercase();
 86            if letter as u32 >= 'a' as u32 {
 87                let ix = (letter as u32) - ('a' as u32);
 88                if ix < 26 {
 89                    result[ix as usize] += 1.0;
 90                }
 91            }
 92        }
 93
 94        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 95        for x in &mut result {
 96            *x /= norm;
 97        }
 98
 99        result.into()
100    }
101}
102
103impl CredentialProvider for FakeEmbeddingProvider {
104    fn has_credentials(&self) -> bool {
105        true
106    }
107    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
108        ProviderCredential::NotNeeded
109    }
110    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
111    fn delete_credentials(&self, _cx: &AppContext) {}
112}
113
114#[async_trait]
115impl EmbeddingProvider for FakeEmbeddingProvider {
116    fn base_model(&self) -> Box<dyn LanguageModel> {
117        Box::new(FakeLanguageModel { capacity: 1000 })
118    }
119    fn max_tokens_per_batch(&self) -> usize {
120        1000
121    }
122
123    fn rate_limit_expiration(&self) -> Option<Instant> {
124        None
125    }
126
127    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
128        self.embedding_count
129            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
130
131        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
132    }
133}
134
135pub struct FakeCompletionProvider {
136    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
137}
138
139impl Clone for FakeCompletionProvider {
140    fn clone(&self) -> Self {
141        Self {
142            last_completion_tx: Mutex::new(None),
143        }
144    }
145}
146
147impl FakeCompletionProvider {
148    pub fn new() -> Self {
149        Self {
150            last_completion_tx: Mutex::new(None),
151        }
152    }
153
154    pub fn send_completion(&self, completion: impl Into<String>) {
155        let mut tx = self.last_completion_tx.lock();
156
157        println!("COMPLETION TX: {:?}", &tx);
158
159        let a = tx.as_mut().unwrap();
160        a.try_send(completion.into()).unwrap();
161
162        // tx.as_mut().unwrap().try_send(completion.into()).unwrap();
163    }
164
165    pub fn finish_completion(&self) {
166        println!("FINISHING COMPLETION");
167        self.last_completion_tx.lock().take().unwrap();
168    }
169}
170
171impl CredentialProvider for FakeCompletionProvider {
172    fn has_credentials(&self) -> bool {
173        true
174    }
175    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
176        ProviderCredential::NotNeeded
177    }
178    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
179    fn delete_credentials(&self, _cx: &AppContext) {}
180}
181
182impl CompletionProvider for FakeCompletionProvider {
183    fn base_model(&self) -> Box<dyn LanguageModel> {
184        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
185        model
186    }
187    fn complete(
188        &self,
189        _prompt: Box<dyn CompletionRequest>,
190    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
191        println!("COMPLETING");
192        let (tx, rx) = mpsc::channel(1);
193        *self.last_completion_tx.lock() = Some(tx);
194        println!("TX: {:?}", *self.last_completion_tx.lock());
195        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
196    }
197    fn box_clone(&self) -> Box<dyn CompletionProvider> {
198        Box::new((*self).clone())
199    }
200}