// model.rs

 1use anyhow::anyhow;
 2use tiktoken_rs::CoreBPE;
 3
 4use crate::models::{LanguageModel, TruncationDirection};
 5
 6use super::open_ai_bpe_tokenizer;
 7
 8#[derive(Clone)]
 9pub struct OpenAiLanguageModel {
10    name: String,
11    bpe: Option<CoreBPE>,
12}
13
14impl OpenAiLanguageModel {
15    pub fn load(model_name: &str) -> Self {
16        let bpe = tiktoken_rs::get_bpe_from_model(model_name)
17            .unwrap_or(open_ai_bpe_tokenizer().to_owned());
18        OpenAiLanguageModel {
19            name: model_name.to_string(),
20            bpe: Some(bpe),
21        }
22    }
23}
24
25impl LanguageModel for OpenAiLanguageModel {
26    fn name(&self) -> String {
27        self.name.clone()
28    }
29    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
30        if let Some(bpe) = &self.bpe {
31            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
32        } else {
33            Err(anyhow!("bpe for open ai model was not retrieved"))
34        }
35    }
36    fn truncate(
37        &self,
38        content: &str,
39        length: usize,
40        direction: TruncationDirection,
41    ) -> anyhow::Result<String> {
42        if let Some(bpe) = &self.bpe {
43            let tokens = bpe.encode_with_special_tokens(content);
44            if tokens.len() > length {
45                match direction {
46                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
47                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
48                }
49            } else {
50                bpe.decode(tokens)
51            }
52        } else {
53            Err(anyhow!("bpe for open ai model was not retrieved"))
54        }
55    }
56    fn capacity(&self) -> anyhow::Result<usize> {
57        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
58    }
59}