test.rs

  1use std::{
  2    sync::atomic::{self, AtomicUsize, Ordering},
  3    time::Instant,
  4};
  5
  6use async_trait::async_trait;
  7
  8use crate::{
  9    auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
 10    embedding::{Embedding, EmbeddingProvider},
 11    models::{LanguageModel, TruncationDirection},
 12};
 13
/// Test double for a language model.
///
/// Its `LanguageModel` implementation in this file counts every `char` of the
/// input as one token and reports the fixed name `"dummy"`.
#[derive(Clone)]
pub struct FakeLanguageModel {
    /// Value returned by `capacity()` in this type's `LanguageModel` impl.
    pub capacity: usize,
}
 18
 19impl LanguageModel for FakeLanguageModel {
 20    fn name(&self) -> String {
 21        "dummy".to_string()
 22    }
 23    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
 24        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
 25    }
 26    fn truncate(
 27        &self,
 28        content: &str,
 29        length: usize,
 30        direction: TruncationDirection,
 31    ) -> anyhow::Result<String> {
 32        if length > self.count_tokens(content)? {
 33            return anyhow::Ok(content.to_string());
 34        }
 35
 36        anyhow::Ok(match direction {
 37            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
 38                .into_iter()
 39                .collect::<String>(),
 40            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
 41                .into_iter()
 42                .collect::<String>(),
 43        })
 44    }
 45    fn capacity(&self) -> anyhow::Result<usize> {
 46        anyhow::Ok(self.capacity)
 47    }
 48}
 49
/// Test double for an embedding provider.
///
/// Embeddings are computed in-process by `embed_sync` from the letter counts
/// of each span.
pub struct FakeEmbeddingProvider {
    /// Counter that `embed_batch` increments by the number of spans in each
    /// batch it receives.
    pub embedding_count: AtomicUsize,
    /// Provider that is cloned and boxed each time `credential_provider()` is
    /// called on the `EmbeddingProvider` impl.
    pub credential_provider: NullCredentialProvider,
}
 54
 55impl Clone for FakeEmbeddingProvider {
 56    fn clone(&self) -> Self {
 57        FakeEmbeddingProvider {
 58            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
 59            credential_provider: self.credential_provider.clone(),
 60        }
 61    }
 62}
 63
 64impl Default for FakeEmbeddingProvider {
 65    fn default() -> Self {
 66        FakeEmbeddingProvider {
 67            embedding_count: AtomicUsize::default(),
 68            credential_provider: NullCredentialProvider {},
 69        }
 70    }
 71}
 72
 73impl FakeEmbeddingProvider {
 74    pub fn embedding_count(&self) -> usize {
 75        self.embedding_count.load(atomic::Ordering::SeqCst)
 76    }
 77
 78    pub fn embed_sync(&self, span: &str) -> Embedding {
 79        let mut result = vec![1.0; 26];
 80        for letter in span.chars() {
 81            let letter = letter.to_ascii_lowercase();
 82            if letter as u32 >= 'a' as u32 {
 83                let ix = (letter as u32) - ('a' as u32);
 84                if ix < 26 {
 85                    result[ix as usize] += 1.0;
 86                }
 87            }
 88        }
 89
 90        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
 91        for x in &mut result {
 92            *x /= norm;
 93        }
 94
 95        result.into()
 96    }
 97}
 98
 99#[async_trait]
100impl EmbeddingProvider for FakeEmbeddingProvider {
101    fn base_model(&self) -> Box<dyn LanguageModel> {
102        Box::new(FakeLanguageModel { capacity: 1000 })
103    }
104    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
105        let credential_provider: Box<dyn CredentialProvider> =
106            Box::new(self.credential_provider.clone());
107        credential_provider
108    }
109    fn max_tokens_per_batch(&self) -> usize {
110        1000
111    }
112
113    fn rate_limit_expiration(&self) -> Option<Instant> {
114        None
115    }
116
117    async fn embed_batch(
118        &self,
119        spans: Vec<String>,
120        _credential: ProviderCredential,
121    ) -> anyhow::Result<Vec<Embedding>> {
122        self.embedding_count
123            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
124
125        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
126    }
127}