1use anyhow::anyhow;
2use tiktoken_rs::CoreBPE;
3use util::ResultExt;
4
/// Abstraction over a tokenizing language model: token counting,
/// token-boundary-aware truncation, and context-window capacity.
pub trait LanguageModel {
    /// Human-readable model identifier (e.g. the model name passed at load time).
    fn name(&self) -> String;
    /// Number of tokens `content` encodes to under this model's tokenizer.
    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
    /// Truncate `content` to at most `length` tokens, keeping the beginning.
    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
    /// Truncate `content` to at most `length` tokens, keeping the end.
    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
    /// Maximum context size (in tokens) this model supports.
    fn capacity(&self) -> anyhow::Result<usize>;
}
12
/// [`LanguageModel`] backed by an OpenAI model's tiktoken tokenizer.
pub struct OpenAILanguageModel {
    // Model name as passed to `load` (e.g. "gpt-4").
    name: String,
    // Tokenizer for `name`; `None` when `get_bpe_from_model` failed at load
    // time (the failure is logged, not propagated — see `load`).
    bpe: Option<CoreBPE>,
}
17
18impl OpenAILanguageModel {
19 pub fn load(model_name: &str) -> Self {
20 let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
21 OpenAILanguageModel {
22 name: model_name.to_string(),
23 bpe,
24 }
25 }
26}
27
28impl LanguageModel for OpenAILanguageModel {
29 fn name(&self) -> String {
30 self.name.clone()
31 }
32 fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
33 if let Some(bpe) = &self.bpe {
34 anyhow::Ok(bpe.encode_with_special_tokens(content).len())
35 } else {
36 Err(anyhow!("bpe for open ai model was not retrieved"))
37 }
38 }
39 fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
40 if let Some(bpe) = &self.bpe {
41 let tokens = bpe.encode_with_special_tokens(content);
42 if tokens.len() > length {
43 bpe.decode(tokens[..length].to_vec())
44 } else {
45 bpe.decode(tokens)
46 }
47 } else {
48 Err(anyhow!("bpe for open ai model was not retrieved"))
49 }
50 }
51 fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
52 if let Some(bpe) = &self.bpe {
53 let tokens = bpe.encode_with_special_tokens(content);
54 if tokens.len() > length {
55 bpe.decode(tokens[length..].to_vec())
56 } else {
57 bpe.decode(tokens)
58 }
59 } else {
60 Err(anyhow!("bpe for open ai model was not retrieved"))
61 }
62 }
63 fn capacity(&self) -> anyhow::Result<usize> {
64 anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
65 }
66}