dummy.rs

 1use std::time::Instant;
 2
 3use crate::{
 4    completion::CompletionRequest,
 5    embedding::{Embedding, EmbeddingProvider},
 6    models::{LanguageModel, TruncationDirection},
 7};
 8use async_trait::async_trait;
 9use serde::Serialize;
10
11pub struct DummyLanguageModel {}
12
13impl LanguageModel for DummyLanguageModel {
14    fn name(&self) -> String {
15        "dummy".to_string()
16    }
17    fn capacity(&self) -> anyhow::Result<usize> {
18        anyhow::Ok(1000)
19    }
20    fn truncate(
21        &self,
22        content: &str,
23        length: usize,
24        direction: crate::models::TruncationDirection,
25    ) -> anyhow::Result<String> {
26        let truncated = match direction {
27            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
28                .iter()
29                .collect::<String>(),
30            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
31                .iter()
32                .collect::<String>(),
33        };
34
35        anyhow::Ok(truncated)
36    }
37    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
38        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
39    }
40}
41
42#[derive(Serialize)]
43pub struct DummyCompletionRequest {
44    pub name: String,
45}
46
47impl CompletionRequest for DummyCompletionRequest {
48    fn data(&self) -> serde_json::Result<String> {
49        serde_json::to_string(self)
50    }
51}
52
53pub struct DummyEmbeddingProvider {}
54
55#[async_trait]
56impl EmbeddingProvider for DummyEmbeddingProvider {
57    fn base_model(&self) -> Box<dyn LanguageModel> {
58        Box::new(DummyLanguageModel {})
59    }
60    fn is_authenticated(&self) -> bool {
61        true
62    }
63    fn rate_limit_expiration(&self) -> Option<Instant> {
64        None
65    }
66    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
67        // 1024 is the OpenAI Embeddings size for ada models.
68        // the model we will likely be starting with.
69        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
70        return Ok(vec![dummy_vec; spans.len()]);
71    }
72
73    fn max_tokens_per_batch(&self) -> usize {
74        8190
75    }
76
77    fn truncate(&self, span: &str) -> (String, usize) {
78        let truncated = span.chars().collect::<Vec<char>>()[..8190]
79            .iter()
80            .collect::<String>();
81        (truncated, 8190)
82    }
83}