test.rs

  1use std::{
  2    sync::atomic::{self, AtomicUsize, Ordering},
  3    time::Instant,
  4};
  5
  6use async_trait::async_trait;
  7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  8use gpui::AppContext;
  9use parking_lot::Mutex;
 10
 11use crate::{
 12    auth::{CredentialProvider, ProviderCredential},
 13    completion::{CompletionProvider, CompletionRequest},
 14    embedding::{Embedding, EmbeddingProvider},
 15    models::{LanguageModel, TruncationDirection},
 16};
 17
 18#[derive(Clone)]
 19pub struct FakeLanguageModel {
 20    pub capacity: usize,
 21}
 22
 23impl LanguageModel for FakeLanguageModel {
 24    fn name(&self) -> String {
 25        "dummy".to_string()
 26    }
 27    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 28        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 29    }
 30    fn truncate(
 31        &self,
 32        content: &str,
 33        length: usize,
 34        direction: TruncationDirection,
 35    ) -> anyhow::Result<String> {
 36        println!("TRYING TO TRUNCATE: {:?}", length.clone());
 37
 38        if length > self.count_tokens(content)? {
 39            println!("NOT TRUNCATING");
 40            return anyhow::Ok(content.to_string());
 41        }
 42
 43        anyhow::Ok(match direction {
 44            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 45                .into_iter()
 46                .collect::<String>(),
 47            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 48                .into_iter()
 49                .collect::<String>(),
 50        })
 51    }
 52    fn capacity(&self) -> anyhow::Result<usize> {
 53        anyhow::Ok(self.capacity)
 54    }
 55}
 56
 57#[derive(Default)]
 58pub struct FakeEmbeddingProvider {
 59    pub embedding_count: AtomicUsize,
 60}
 61
 62impl Clone for FakeEmbeddingProvider {
 63    fn clone(&self) -> Self {
 64        FakeEmbeddingProvider {
 65            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 66        }
 67    }
 68}
 69
 70impl FakeEmbeddingProvider {
 71    pub fn embedding_count(&self) -> usize {
 72        self.embedding_count.load(atomic::Ordering::SeqCst)
 73    }
 74
 75    pub fn embed_sync(&self, span: &str) -> Embedding {
 76        let mut result = vec![1.0; 26];
 77        for letter in span.chars() {
 78            let letter = letter.to_ascii_lowercase();
 79            if letter as u32 >= 'a' as u32 {
 80                let ix = (letter as u32) - ('a' as u32);
 81                if ix < 26 {
 82                    result[ix as usize] += 1.0;
 83                }
 84            }
 85        }
 86
 87        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 88        for x in &mut result {
 89            *x /= norm;
 90        }
 91
 92        result.into()
 93    }
 94}
 95
 96impl CredentialProvider for FakeEmbeddingProvider {
 97    fn has_credentials(&self) -> bool {
 98        true
 99    }
100
101    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
102        async { ProviderCredential::NotNeeded }.boxed()
103    }
104
105    fn save_credentials(
106        &self,
107        _cx: &mut AppContext,
108        _credential: ProviderCredential,
109    ) -> BoxFuture<()> {
110        async {}.boxed()
111    }
112
113    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
114        async {}.boxed()
115    }
116}
117
118#[async_trait]
119impl EmbeddingProvider for FakeEmbeddingProvider {
120    fn base_model(&self) -> Box<dyn LanguageModel> {
121        Box::new(FakeLanguageModel { capacity: 1000 })
122    }
123    fn max_tokens_per_batch(&self) -> usize {
124        1000
125    }
126
127    fn rate_limit_expiration(&self) -> Option<Instant> {
128        None
129    }
130
131    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
132        self.embedding_count
133            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
134
135        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
136    }
137}
138
/// Fake [`CompletionProvider`] for tests.
///
/// Text is pushed into the stream via `send_completion`. `finish_completion` drops the
/// sender for that stream.
#[derive(Default)]
pub struct FakeCompletionProvider {
    /// Sender half of the channel created by the most recent `complete` call, or `None`
    /// when no completion is pending.
    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
142
143impl Clone for FakeCompletionProvider {
144    fn clone(&self) -> Self {
145        Self {
146            last_completion_tx: Mutex::new(None),
147        }
148    }
149}
150
151impl FakeCompletionProvider {
152    pub fn new() -> Self {
153        Self {
154            last_completion_tx: Mutex::new(None),
155        }
156    }
157
158    pub fn send_completion(&self, completion: impl Into<String>) {
159        let mut tx = self.last_completion_tx.lock();
160        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
161    }
162
163    pub fn finish_completion(&self) {
164        self.last_completion_tx.lock().take().unwrap();
165    }
166}
167
168impl CredentialProvider for FakeCompletionProvider {
169    fn has_credentials(&self) -> bool {
170        true
171    }
172
173    fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
174        async { ProviderCredential::NotNeeded }.boxed()
175    }
176
177    fn save_credentials(
178        &self,
179        _cx: &mut AppContext,
180        _credential: ProviderCredential,
181    ) -> BoxFuture<()> {
182        async {}.boxed()
183    }
184
185    fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
186        async {}.boxed()
187    }
188}
189
190impl CompletionProvider for FakeCompletionProvider {
191    fn base_model(&self) -> Box<dyn LanguageModel> {
192        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
193        model
194    }
195    fn complete(
196        &self,
197        _prompt: Box<dyn CompletionRequest>,
198    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
199        let (tx, rx) = mpsc::channel(1);
200        *self.last_completion_tx.lock() = Some(tx);
201        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
202    }
203    fn box_clone(&self) -> Box<dyn CompletionProvider> {
204        Box::new((*self).clone())
205    }
206}