1use anyhow::anyhow;
2use tiktoken_rs::CoreBPE;
3use util::ResultExt;
4
5use crate::models::{LanguageModel, TruncationDirection};
6
/// A language model handle backed by OpenAI's tokenizer tables
/// (`tiktoken_rs`), used for token counting and truncation.
#[derive(Clone)]
pub struct OpenAILanguageModel {
    // Model identifier as passed to `load` (e.g. "gpt-4").
    name: String,
    // Byte-pair encoder for `name`; `None` when `tiktoken_rs` had no
    // table for the model, in which case token operations return errors.
    bpe: Option<CoreBPE>,
}
12
13impl OpenAILanguageModel {
14 pub fn load(model_name: &str) -> Self {
15 let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
16 OpenAILanguageModel {
17 name: model_name.to_string(),
18 bpe,
19 }
20 }
21}
22
23impl LanguageModel for OpenAILanguageModel {
24 fn name(&self) -> String {
25 self.name.clone()
26 }
27 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
28 if let Some(bpe) = &self.bpe {
29 anyhow::Ok(bpe.encode_with_special_tokens(content).len())
30 } else {
31 Err(anyhow!("bpe for open ai model was not retrieved"))
32 }
33 }
34 fn truncate(
35 &self,
36 content: &str,
37 length: usize,
38 direction: TruncationDirection,
39 ) -> anyhow::Result<String> {
40 if let Some(bpe) = &self.bpe {
41 let tokens = bpe.encode_with_special_tokens(content);
42 if tokens.len() > length {
43 match direction {
44 TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
45 TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
46 }
47 } else {
48 bpe.decode(tokens)
49 }
50 } else {
51 Err(anyhow!("bpe for open ai model was not retrieved"))
52 }
53 }
54 fn capacity(&self) -> anyhow::Result<usize> {
55 anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
56 }
57}