test.rs

  1use std::{
  2    sync::atomic::{self, AtomicUsize, Ordering},
  3    time::Instant,
  4};
  5
  6use async_trait::async_trait;
  7use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  8use parking_lot::Mutex;
  9
 10use crate::{
 11    auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
 12    completion::{CompletionProvider, CompletionRequest},
 13    embedding::{Embedding, EmbeddingProvider},
 14    models::{LanguageModel, TruncationDirection},
 15};
 16
 17#[derive(Clone)]
 18pub struct FakeLanguageModel {
 19    pub capacity: usize,
 20}
 21
 22impl LanguageModel for FakeLanguageModel {
 23    fn name(&self) -> String {
 24        "dummy".to_string()
 25    }
 26    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 27        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 28    }
 29    fn truncate(
 30        &self,
 31        content: &str,
 32        length: usize,
 33        direction: TruncationDirection,
 34    ) -> anyhow::Result<String> {
 35        if length > self.count_tokens(content)? {
 36            return anyhow::Ok(content.to_string());
 37        }
 38
 39        anyhow::Ok(match direction {
 40            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 41                .into_iter()
 42                .collect::<String>(),
 43            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 44                .into_iter()
 45                .collect::<String>(),
 46        })
 47    }
 48    fn capacity(&self) -> anyhow::Result<usize> {
 49        anyhow::Ok(self.capacity)
 50    }
 51}
 52
 53pub struct FakeEmbeddingProvider {
 54    pub embedding_count: AtomicUsize,
 55    pub credential_provider: NullCredentialProvider,
 56}
 57
 58impl Clone for FakeEmbeddingProvider {
 59    fn clone(&self) -> Self {
 60        FakeEmbeddingProvider {
 61            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 62            credential_provider: self.credential_provider.clone(),
 63        }
 64    }
 65}
 66
 67impl Default for FakeEmbeddingProvider {
 68    fn default() -> Self {
 69        FakeEmbeddingProvider {
 70            embedding_count: AtomicUsize::default(),
 71            credential_provider: NullCredentialProvider {},
 72        }
 73    }
 74}
 75
 76impl FakeEmbeddingProvider {
 77    pub fn embedding_count(&self) -> usize {
 78        self.embedding_count.load(atomic::Ordering::SeqCst)
 79    }
 80
 81    pub fn embed_sync(&self, span: &str) -> Embedding {
 82        let mut result = vec![1.0; 26];
 83        for letter in span.chars() {
 84            let letter = letter.to_ascii_lowercase();
 85            if letter as u32 >= 'a' as u32 {
 86                let ix = (letter as u32) - ('a' as u32);
 87                if ix < 26 {
 88                    result[ix as usize] += 1.0;
 89                }
 90            }
 91        }
 92
 93        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 94        for x in &mut result {
 95            *x /= norm;
 96        }
 97
 98        result.into()
 99    }
100}
101
102#[async_trait]
103impl EmbeddingProvider for FakeEmbeddingProvider {
104    fn base_model(&self) -> Box<dyn LanguageModel> {
105        Box::new(FakeLanguageModel { capacity: 1000 })
106    }
107    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
108        let credential_provider: Box<dyn CredentialProvider> =
109            Box::new(self.credential_provider.clone());
110        credential_provider
111    }
112    fn max_tokens_per_batch(&self) -> usize {
113        1000
114    }
115
116    fn rate_limit_expiration(&self) -> Option<Instant> {
117        None
118    }
119
120    async fn embed_batch(
121        &self,
122        spans: Vec<String>,
123        _credential: ProviderCredential,
124    ) -> anyhow::Result<Vec<Embedding>> {
125        self.embedding_count
126            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
127
128        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
129    }
130}
131
132pub struct TestCompletionProvider {
133    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
134}
135
136impl TestCompletionProvider {
137    pub fn new() -> Self {
138        Self {
139            last_completion_tx: Mutex::new(None),
140        }
141    }
142
143    pub fn send_completion(&self, completion: impl Into<String>) {
144        let mut tx = self.last_completion_tx.lock();
145        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
146    }
147
148    pub fn finish_completion(&self) {
149        self.last_completion_tx.lock().take().unwrap();
150    }
151}
152
153impl CompletionProvider for TestCompletionProvider {
154    fn base_model(&self) -> Box<dyn LanguageModel> {
155        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
156        model
157    }
158    fn complete(
159        &self,
160        _prompt: Box<dyn CompletionRequest>,
161    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
162        let (tx, rx) = mpsc::channel(1);
163        *self.last_completion_tx.lock() = Some(tx);
164        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
165    }
166}