// model.rs

 1use anyhow::anyhow;
 2use tiktoken_rs::CoreBPE;
 3use util::ResultExt;
 4
 5use crate::models::{LanguageModel, TruncationDirection};
 6
/// A language model backed by OpenAI's tiktoken tokenizer.
///
/// Stores the model name together with the byte-pair encoder used for
/// token counting and truncation. `bpe` is `None` when tiktoken has no
/// encoding registered for the model name (lookup failure is logged at
/// load time rather than treated as fatal).
#[derive(Clone)]
pub struct OpenAILanguageModel {
    // Model identifier, e.g. "gpt-4"; also used to look up context size.
    name: String,
    // `None` if `tiktoken_rs::get_bpe_from_model` failed during `load`.
    bpe: Option<CoreBPE>,
}
12
13impl OpenAILanguageModel {
14    pub fn load(model_name: &str) -> Self {
15        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
16        OpenAILanguageModel {
17            name: model_name.to_string(),
18            bpe,
19        }
20    }
21}
22
23impl LanguageModel for OpenAILanguageModel {
24    fn name(&self) -> String {
25        self.name.clone()
26    }
27    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
28        if let Some(bpe) = &self.bpe {
29            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
30        } else {
31            Err(anyhow!("bpe for open ai model was not retrieved"))
32        }
33    }
34    fn truncate(
35        &self,
36        content: &str,
37        length: usize,
38        direction: TruncationDirection,
39    ) -> anyhow::Result<String> {
40        if let Some(bpe) = &self.bpe {
41            let tokens = bpe.encode_with_special_tokens(content);
42            if tokens.len() > length {
43                match direction {
44                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
45                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
46                }
47            } else {
48                bpe.decode(tokens)
49            }
50        } else {
51            Err(anyhow!("bpe for open ai model was not retrieved"))
52        }
53    }
54    fn capacity(&self) -> anyhow::Result<usize> {
55        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
56    }
57}