1use anyhow::anyhow;
2use tiktoken_rs::CoreBPE;
3
4use crate::models::{LanguageModel, TruncationDirection};
5
6use super::open_ai_bpe_tokenizer;
7
/// A language model backed by an OpenAI model, using a tiktoken BPE
/// tokenizer for token counting and truncation.
#[derive(Clone)]
pub struct OpenAiLanguageModel {
    // Model identifier (e.g. passed to `load`); also used to look up the
    // model's context size in `capacity`.
    name: String,
    // Tokenizer for this model. `load` always stores `Some`; the `Option`
    // presumably guards against a retrieval failure path — confirm with callers.
    bpe: Option<CoreBPE>,
}
13
14impl OpenAiLanguageModel {
15 pub fn load(model_name: &str) -> Self {
16 let bpe = tiktoken_rs::get_bpe_from_model(model_name)
17 .unwrap_or(open_ai_bpe_tokenizer().to_owned());
18 OpenAiLanguageModel {
19 name: model_name.to_string(),
20 bpe: Some(bpe),
21 }
22 }
23}
24
25impl LanguageModel for OpenAiLanguageModel {
26 fn name(&self) -> String {
27 self.name.clone()
28 }
29 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
30 if let Some(bpe) = &self.bpe {
31 anyhow::Ok(bpe.encode_with_special_tokens(content).len())
32 } else {
33 Err(anyhow!("bpe for open ai model was not retrieved"))
34 }
35 }
36 fn truncate(
37 &self,
38 content: &str,
39 length: usize,
40 direction: TruncationDirection,
41 ) -> anyhow::Result<String> {
42 if let Some(bpe) = &self.bpe {
43 let tokens = bpe.encode_with_special_tokens(content);
44 if tokens.len() > length {
45 match direction {
46 TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
47 TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
48 }
49 } else {
50 bpe.decode(tokens)
51 }
52 } else {
53 Err(anyhow!("bpe for open ai model was not retrieved"))
54 }
55 }
56 fn capacity(&self) -> anyhow::Result<usize> {
57 anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
58 }
59}