// models.rs

 1use anyhow::anyhow;
 2use tiktoken_rs::CoreBPE;
 3use util::ResultExt;
 4
/// Abstraction over a text-generation model's tokenizer-level operations:
/// token counting, token-boundary truncation, and context-window size.
pub trait LanguageModel {
    /// Human-readable model identifier (e.g. the OpenAI model name).
    fn name(&self) -> String;
    /// Number of tokens `content` encodes to under this model's tokenizer.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    /// Returns `content` truncated to at most `length` tokens, re-decoded to a string.
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
    /// Maximum context size (in tokens) the model supports.
    fn capacity(&self) -> anyhow::Result<usize>;
}
11
/// [`LanguageModel`] backed by an OpenAI tokenizer from `tiktoken_rs`.
pub struct OpenAILanguageModel {
    // Model identifier passed to `load`; also used to look up the context size.
    name: String,
    // Byte-pair encoder for this model. `None` when `get_bpe_from_model` failed
    // at load time (the error was logged); token operations then return `Err`.
    bpe: Option<CoreBPE>,
}
16
17impl OpenAILanguageModel {
18    pub fn load(model_name: &str) -> Self {
19        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
20        OpenAILanguageModel {
21            name: model_name.to_string(),
22            bpe,
23        }
24    }
25}
26
27impl LanguageModel for OpenAILanguageModel {
28    fn name(&self) -> String {
29        self.name.clone()
30    }
31    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
32        if let Some(bpe) = &self.bpe {
33            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
34        } else {
35            Err(anyhow!("bpe for open ai model was not retrieved"))
36        }
37    }
38    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
39        if let Some(bpe) = &self.bpe {
40            let tokens = bpe.encode_with_special_tokens(content);
41            bpe.decode(tokens[..length].to_vec())
42        } else {
43            Err(anyhow!("bpe for open ai model was not retrieved"))
44        }
45    }
46    fn capacity(&self) -> anyhow::Result<usize> {
47        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
48    }
49}