1use std::time::Instant;
2
3use crate::{
4 completion::CompletionRequest,
5 embedding::{Embedding, EmbeddingProvider},
6 models::{LanguageModel, TruncationDirection},
7};
8use async_trait::async_trait;
9use serde::Serialize;
10
/// Test double for a language model: it counts one `char` as one token and
/// needs no tokenizer or network access.
#[derive(Debug, Default, Clone, Copy)]
pub struct DummyLanguageModel {}
12
13impl LanguageModel for DummyLanguageModel {
14 fn name(&self) -> String {
15 "dummy".to_string()
16 }
17 fn capacity(&self) -> anyhow::Result<usize> {
18 anyhow::Ok(1000)
19 }
20 fn truncate(
21 &self,
22 content: &str,
23 length: usize,
24 direction: crate::models::TruncationDirection,
25 ) -> anyhow::Result<String> {
26 let truncated = match direction {
27 TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
28 .iter()
29 .collect::<String>(),
30 TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
31 .iter()
32 .collect::<String>(),
33 };
34
35 anyhow::Ok(truncated)
36 }
37 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
38 anyhow::Ok(content.chars().collect::<Vec<char>>().len())
39 }
40}
41
42#[derive(Serialize)]
43pub struct DummyCompletionRequest {
44 pub name: String,
45}
46
47impl CompletionRequest for DummyCompletionRequest {
48 fn data(&self) -> serde_json::Result<String> {
49 serde_json::to_string(self)
50 }
51}
52
/// Test double for an embedding provider: always authenticated, never rate
/// limited, and returns constant embeddings without any network access.
#[derive(Debug, Default, Clone, Copy)]
pub struct DummyEmbeddingProvider {}
54
55#[async_trait]
56impl EmbeddingProvider for DummyEmbeddingProvider {
57 fn base_model(&self) -> Box<dyn LanguageModel> {
58 Box::new(DummyLanguageModel {})
59 }
60 fn is_authenticated(&self) -> bool {
61 true
62 }
63 fn rate_limit_expiration(&self) -> Option<Instant> {
64 None
65 }
66 async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
67 // 1024 is the OpenAI Embeddings size for ada models.
68 // the model we will likely be starting with.
69 let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
70 return Ok(vec![dummy_vec; spans.len()]);
71 }
72
73 fn max_tokens_per_batch(&self) -> usize {
74 8190
75 }
76
77 fn truncate(&self, span: &str) -> (String, usize) {
78 let truncated = span.chars().collect::<Vec<char>>()[..8190]
79 .iter()
80 .collect::<String>();
81 (truncated, 8190)
82 }
83}