use anyhow::anyhow;
use tiktoken_rs::CoreBPE;
use util::ResultExt;

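/// Tokenizer-level interface to a language model: counting tokens in a piece
/// of content, truncating content to a token budget, and reporting the
/// model's context-window capacity.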
pub trait LanguageModel {
    fn name(&self) -> String;
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
    fn capacity(&self) -> anyhow::Result<usize>;
}

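/// [`LanguageModel`] implementation backed by tiktoken-rs. `bpe` is `None`
/// when no tokenizer could be loaded for the given model name, in which case
/// token-based operations return an error.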
pub struct OpenAILanguageModel {
    name: String,
    bpe: Option<CoreBPE>,
}

impl OpenAILanguageModel {
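    /// Looks up the tiktoken BPE tokenizer for `model_name`. A failed lookup
    /// is logged via `log_err` and leaves `bpe` unset rather than panicking
    /// or returning an error here.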
    pub fn load(model_name: &str) -> Self {
        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
        OpenAILanguageModel {
            name: model_name.to_string(),
            bpe,
        }
    }
}

impl LanguageModel for OpenAILanguageModel {
    fn name(&self) -> String {
        self.name.clone()
    }

    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
        if let Some(bpe) = &self.bpe {
            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
        } else {
            Err(anyhow!("BPE tokenizer for OpenAI model was not retrieved"))
        }
    }

    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
        if let Some(bpe) = &self.bpe {
            let tokens = bpe.encode_with_special_tokens(content);
            if tokens.len() > length {
                // Keep only the first `length` tokens and decode them back to text.
                bpe.decode(tokens[..length].to_vec())
            } else {
                // Content already fits within the budget; decoding the full
                // token sequence round-trips it unchanged.
                bpe.decode(tokens)
            }
        } else {
            Err(anyhow!("BPE tokenizer for OpenAI model was not retrieved"))
        }
    }

    fn capacity(&self) -> anyhow::Result<usize> {
        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
    }
}
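
// Minimal usage sketch (illustrative, not part of the original module): it
// assumes tiktoken-rs recognizes the "gpt-3.5-turbo" model name, so `load`
// succeeds in fetching a BPE tokenizer.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn truncate_respects_token_limit() {
        let model = OpenAILanguageModel::load("gpt-3.5-turbo");
        let text = "The quick brown fox jumps over the lazy dog. ".repeat(8);

        // Truncation keeps at most 10 tokens; re-encoding the decoded prefix
        // of plain ASCII text should stay within that budget.
        let truncated = model.truncate(&text, 10).unwrap();
        assert!(model.count_tokens(&truncated).unwrap() <= 10);

        // The full text is far smaller than the model's context window.
        assert!(model.count_tokens(&text).unwrap() <= model.capacity().unwrap());
    }
}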