models.rs

 1use anyhow::anyhow;
 2use tiktoken_rs::CoreBPE;
 3use util::ResultExt;
 4
 5pub trait LanguageModel {
 6    fn name(&self) -> String;
 7    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
 8    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
 9    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
10    fn capacity(&self) -> anyhow::Result<usize>;
11}
12
13pub struct OpenAILanguageModel {
14    name: String,
15    bpe: Option<CoreBPE>,
16}
17
18impl OpenAILanguageModel {
19    pub fn load(model_name: &str) -> Self {
20        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
21        OpenAILanguageModel {
22            name: model_name.to_string(),
23            bpe,
24        }
25    }
26}
27
28impl LanguageModel for OpenAILanguageModel {
29    fn name(&self) -> String {
30        self.name.clone()
31    }
32    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
33        if let Some(bpe) = &self.bpe {
34            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
35        } else {
36            Err(anyhow!("bpe for open ai model was not retrieved"))
37        }
38    }
39    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
40        if let Some(bpe) = &self.bpe {
41            let tokens = bpe.encode_with_special_tokens(content);
42            if tokens.len() > length {
43                bpe.decode(tokens[..length].to_vec())
44            } else {
45                bpe.decode(tokens)
46            }
47        } else {
48            Err(anyhow!("bpe for open ai model was not retrieved"))
49        }
50    }
51    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
52        if let Some(bpe) = &self.bpe {
53            let tokens = bpe.encode_with_special_tokens(content);
54            if tokens.len() > length {
55                bpe.decode(tokens[length..].to_vec())
56            } else {
57                bpe.decode(tokens)
58            }
59        } else {
60            Err(anyhow!("bpe for open ai model was not retrieved"))
61        }
62    }
63    fn capacity(&self) -> anyhow::Result<usize> {
64        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
65    }
66}