test.rs

  1use std::{
  2    sync::atomic::{self, AtomicUsize, Ordering},
  3    time::Instant,
  4};
  5
  6use async_trait::async_trait;
  7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  8use gpui2::AppContext;
  9use parking_lot::Mutex;
 10
 11use crate::{
 12    auth::{CredentialProvider, ProviderCredential},
 13    completion::{CompletionProvider, CompletionRequest},
 14    embedding::{Embedding, EmbeddingProvider},
 15    models::{LanguageModel, TruncationDirection},
 16};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModel {
 20    pub capacity: usize,
 21}
 22
 23impl LanguageModel for FakeLanguageModel {
 24    fn name(&self) -> String {
 25        "dummy".to_string()
 26    }
 27    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 28        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 29    }
 30    fn truncate(
 31        &self,
 32        content: &str,
 33        length: usize,
 34        direction: TruncationDirection,
 35    ) -> anyhow::Result<String> {
 36        println!("TRYING TO TRUNCATE: {:?}", length.clone());
 37
 38        if length > self.count_tokens(content)? {
 39            println!("NOT TRUNCATING");
 40            return anyhow::Ok(content.to_string());
 41        }
 42
 43        anyhow::Ok(match direction {
 44            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 45                .into_iter()
 46                .collect::<String>(),
 47            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 48                .into_iter()
 49                .collect::<String>(),
 50        })
 51    }
 52    fn capacity(&self) -> anyhow::Result<usize> {
 53        anyhow::Ok(self.capacity)
 54    }
 55}
 56
 57pub struct FakeEmbeddingProvider {
 58    pub embedding_count: AtomicUsize,
 59}
 60
 61impl Clone for FakeEmbeddingProvider {
 62    fn clone(&self) -> Self {
 63        FakeEmbeddingProvider {
 64            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 65        }
 66    }
 67}
 68
 69impl Default for FakeEmbeddingProvider {
 70    fn default() -> Self {
 71        FakeEmbeddingProvider {
 72            embedding_count: AtomicUsize::default(),
 73        }
 74    }
 75}
 76
 77impl FakeEmbeddingProvider {
 78    pub fn embedding_count(&self) -> usize {
 79        self.embedding_count.load(atomic::Ordering::SeqCst)
 80    }
 81
 82    pub fn embed_sync(&self, span: &str) -> Embedding {
 83        let mut result = vec![1.0; 26];
 84        for letter in span.chars() {
 85            let letter = letter.to_ascii_lowercase();
 86            if letter as u32 >= 'a' as u32 {
 87                let ix = (letter as u32) - ('a' as u32);
 88                if ix < 26 {
 89                    result[ix as usize] += 1.0;
 90                }
 91            }
 92        }
 93
 94        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 95        for x in &mut result {
 96            *x /= norm;
 97        }
 98
 99        result.into()
100    }
101}
102
103#[async_trait]
104impl CredentialProvider for FakeEmbeddingProvider {
105    fn has_credentials(&self) -> bool {
106        true
107    }
108    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
109        ProviderCredential::NotNeeded
110    }
111    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
112    async fn delete_credentials(&self, _cx: &mut AppContext) {}
113}
114
115#[async_trait]
116impl EmbeddingProvider for FakeEmbeddingProvider {
117    fn base_model(&self) -> Box<dyn LanguageModel> {
118        Box::new(FakeLanguageModel { capacity: 1000 })
119    }
120    fn max_tokens_per_batch(&self) -> usize {
121        1000
122    }
123
124    fn rate_limit_expiration(&self) -> Option<Instant> {
125        None
126    }
127
128    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
129        self.embedding_count
130            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
131
132        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
133    }
134}
135
136pub struct FakeCompletionProvider {
137    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
138}
139
140impl Clone for FakeCompletionProvider {
141    fn clone(&self) -> Self {
142        Self {
143            last_completion_tx: Mutex::new(None),
144        }
145    }
146}
147
148impl FakeCompletionProvider {
149    pub fn new() -> Self {
150        Self {
151            last_completion_tx: Mutex::new(None),
152        }
153    }
154
155    pub fn send_completion(&self, completion: impl Into<String>) {
156        let mut tx = self.last_completion_tx.lock();
157        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
158    }
159
160    pub fn finish_completion(&self) {
161        self.last_completion_tx.lock().take().unwrap();
162    }
163}
164
165#[async_trait]
166impl CredentialProvider for FakeCompletionProvider {
167    fn has_credentials(&self) -> bool {
168        true
169    }
170    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
171        ProviderCredential::NotNeeded
172    }
173    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
174    async fn delete_credentials(&self, _cx: &mut AppContext) {}
175}
176
177impl CompletionProvider for FakeCompletionProvider {
178    fn base_model(&self) -> Box<dyn LanguageModel> {
179        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
180        model
181    }
182    fn complete(
183        &self,
184        _prompt: Box<dyn CompletionRequest>,
185    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
186        let (tx, rx) = mpsc::channel(1);
187        *self.last_completion_tx.lock() = Some(tx);
188        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
189    }
190    fn box_clone(&self) -> Box<dyn CompletionProvider> {
191        Box::new((*self).clone())
192    }
193}