models.rs

use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;

/// A language model that can report its name, count and truncate tokens in a
/// string of content, and report the total token capacity of its context window.
pub trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}

/// A `LanguageModel` backed by an OpenAI model, using `tiktoken_rs` for tokenization.
pub struct OpenAILanguageModel {
    name: String,
    bpe: Option<CoreBPE>,
}

impl OpenAILanguageModel {
    /// Creates a model for `model_name`, loading its BPE tokenizer if one is known
    /// to `tiktoken_rs`. If loading fails, the error is logged and the token-based
    /// methods below return errors.
    pub fn load(model_name: &str) -> Self {
        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
        OpenAILanguageModel {
            name: model_name.to_string(),
            bpe,
        }
    }
}

impl LanguageModel for OpenAILanguageModel {
    fn name(&self) -> String {
        self.name.clone()
    }

    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        if let Some(bpe) = &self.bpe {
            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
        } else {
            Err(anyhow!("BPE for OpenAI model was not retrieved"))
        }
    }

    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
        if let Some(bpe) = &self.bpe {
            let tokens = bpe.encode_with_special_tokens(content);
            if tokens.len() > length {
                // Keep only the first `length` tokens, then decode back to text.
                bpe.decode(tokens[..length].to_vec())
            } else {
                // Content already fits within the budget; decode it unchanged.
                bpe.decode(tokens)
            }
        } else {
            Err(anyhow!("BPE for OpenAI model was not retrieved"))
        }
    }

    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}
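
// An illustrative usage sketch (not part of the original module): a minimal test
// exercising the trait through `OpenAILanguageModel`. It assumes that "gpt-4" is a
// model name recognized by `tiktoken_rs` and that its embedded tokenizer data is
// available; the assertions are examples, not guaranteed exact values.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn openai_model_counts_truncates_and_reports_capacity() {
        let model = OpenAILanguageModel::load("gpt-4");

        // Token counting succeeds once the BPE tokenizer was loaded.
        let content = "The quick brown fox jumps over the lazy dog.";
        let token_count = model.count_tokens(content).unwrap();
        assert!(token_count > 0);

        // Truncating to a smaller token budget keeps a prefix of the content.
        let truncated = model.truncate(content, 3).unwrap();
        assert!(truncated.len() < content.len());
        assert!(content.starts_with(&truncated));

        // Truncating to a budget at or above the token count leaves it unchanged.
        let untouched = model.truncate(content, token_count).unwrap();
        assert_eq!(untouched, content);

        // Capacity reflects the model's context window, which should comfortably
        // exceed a single sentence.
        assert!(model.capacity().unwrap() > token_count);
    }
}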