test.rs

use std::{
    sync::{
        atomic::{self, AtomicUsize, Ordering},
        Arc,
    },
    time::Instant,
};

use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::AppContext;
use parking_lot::Mutex;

use crate::{
    auth::{CredentialProvider, ProviderCredential},
    completion::{CompletionProvider, CompletionRequest},
    embedding::{Embedding, EmbeddingProvider},
    models::{LanguageModel, TruncationDirection},
};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModel {
 20    pub capacity: usize,
 21}
 22
 23impl LanguageModel for FakeLanguageModel {
 24    fn name(&self) -> String {
 25        "dummy".to_string()
 26    }
 27    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 28        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 29    }
 30    fn truncate(
 31        &self,
 32        content: &str,
 33        length: usize,
 34        direction: TruncationDirection,
 35    ) -> anyhow::Result<String> {
 36        println!("TRYING TO TRUNCATE: {:?}", length.clone());
 37
 38        if length > self.count_tokens(content)? {
 39            println!("NOT TRUNCATING");
 40            return anyhow::Ok(content.to_string());
 41        }
 42
 43        anyhow::Ok(match direction {
 44            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 45                .into_iter()
 46                .collect::<String>(),
 47            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 48                .into_iter()
 49                .collect::<String>(),
 50        })
 51    }
 52    fn capacity(&self) -> anyhow::Result<usize> {
 53        anyhow::Ok(self.capacity)
 54    }
 55}
 56
 57pub struct FakeEmbeddingProvider {
 58    pub embedding_count: AtomicUsize,
 59}
 60
 61impl Clone for FakeEmbeddingProvider {
 62    fn clone(&self) -> Self {
 63        FakeEmbeddingProvider {
 64            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 65        }
 66    }
 67}
 68
 69impl Default for FakeEmbeddingProvider {
 70    fn default() -> Self {
 71        FakeEmbeddingProvider {
 72            embedding_count: AtomicUsize::default(),
 73        }
 74    }
 75}
 76
 77impl FakeEmbeddingProvider {
 78    pub fn embedding_count(&self) -> usize {
 79        self.embedding_count.load(atomic::Ordering::SeqCst)
 80    }
 81
 82    pub fn embed_sync(&self, span: &str) -> Embedding {
 83        let mut result = vec![1.0; 26];
 84        for letter in span.chars() {
 85            let letter = letter.to_ascii_lowercase();
 86            if letter as u32 >= 'a' as u32 {
 87                let ix = (letter as u32) - ('a' as u32);
 88                if ix < 26 {
 89                    result[ix as usize] += 1.0;
 90                }
 91            }
 92        }
 93
 94        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 95        for x in &mut result {
 96            *x /= norm;
 97        }
 98
 99        result.into()
100    }
101}
102
103impl CredentialProvider for FakeEmbeddingProvider {
104    fn has_credentials(&self) -> bool {
105        true
106    }
107    fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
108        ProviderCredential::NotNeeded
109    }
110    fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
111    fn delete_credentials(&self, _cx: &mut AppContext) {}
112}
113
114#[async_trait]
115impl EmbeddingProvider for FakeEmbeddingProvider {
116    fn base_model(&self) -> Box<dyn LanguageModel> {
117        Box::new(FakeLanguageModel { capacity: 1000 })
118    }
119    fn max_tokens_per_batch(&self) -> usize {
120        1000
121    }
122
123    fn rate_limit_expiration(&self) -> Option<Instant> {
124        None
125    }
126
127    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
128        self.embedding_count
129            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
130
131        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
132    }
133}
134
135pub struct FakeCompletionProvider {
136    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
137}
138
139impl Clone for FakeCompletionProvider {
140    fn clone(&self) -> Self {
141        Self {
142            last_completion_tx: Mutex::new(None),
143        }
144    }
145}
146
147impl FakeCompletionProvider {
148    pub fn new() -> Self {
149        Self {
150            last_completion_tx: Mutex::new(None),
151        }
152    }
153
154    pub fn send_completion(&self, completion: impl Into<String>) {
155        let mut tx = self.last_completion_tx.lock();
156        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
157    }
158
159    pub fn finish_completion(&self) {
160        self.last_completion_tx.lock().take().unwrap();
161    }
162}
163
164impl CredentialProvider for FakeCompletionProvider {
165    fn has_credentials(&self) -> bool {
166        true
167    }
168    fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
169        ProviderCredential::NotNeeded
170    }
171    fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
172    fn delete_credentials(&self, _cx: &mut AppContext) {}
173}
174
175impl CompletionProvider for FakeCompletionProvider {
176    fn base_model(&self) -> Box<dyn LanguageModel> {
177        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
178        model
179    }
180    fn complete(
181        &self,
182        _prompt: Box<dyn CompletionRequest>,
183    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
184        let (tx, rx) = mpsc::channel(1);
185        *self.last_completion_tx.lock() = Some(tx);
186        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
187    }
188    fn box_clone(&self) -> Box<dyn CompletionProvider> {
189        Box::new((*self).clone())
190    }
191}