1use crate::{
2 model::{CloudModel, LanguageModel},
3 role::Role,
4};
5use serde::{Deserialize, Serialize};
6
7#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
8pub struct LanguageModelRequestMessage {
9 pub role: Role,
10 pub content: String,
11}
12
13impl LanguageModelRequestMessage {
14 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
15 proto::LanguageModelRequestMessage {
16 role: self.role.to_proto() as i32,
17 content: self.content.clone(),
18 tool_calls: Vec::new(),
19 tool_call_id: None,
20 }
21 }
22}
23
24#[derive(Debug, Default, Serialize, Deserialize)]
25pub struct LanguageModelRequest {
26 pub model: LanguageModel,
27 pub messages: Vec<LanguageModelRequestMessage>,
28 pub stop: Vec<String>,
29 pub temperature: f32,
30}
31
32impl LanguageModelRequest {
33 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
34 proto::CompleteWithLanguageModel {
35 model: self.model.id().to_string(),
36 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
37 stop: self.stop.clone(),
38 temperature: self.temperature,
39 tool_choice: None,
40 tools: Vec::new(),
41 }
42 }
43
44 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
45 pub fn preprocess(&mut self) {
46 match &self.model {
47 LanguageModel::OpenAi(_) => {}
48 LanguageModel::Anthropic(_) => self.preprocess_anthropic(),
49 LanguageModel::Ollama(_) => {}
50 LanguageModel::Cloud(model) => match model {
51 CloudModel::Claude3Opus
52 | CloudModel::Claude3Sonnet
53 | CloudModel::Claude3Haiku
54 | CloudModel::Claude3_5Sonnet => {
55 self.preprocess_anthropic();
56 }
57 CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
58 self.preprocess_anthropic();
59 }
60 _ => {}
61 },
62 }
63 }
64
65 pub fn preprocess_anthropic(&mut self) {
66 let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
67 let mut system_message = String::new();
68
69 for message in self.messages.drain(..) {
70 if message.content.is_empty() {
71 continue;
72 }
73
74 match message.role {
75 Role::User | Role::Assistant => {
76 if let Some(last_message) = new_messages.last_mut() {
77 if last_message.role == message.role {
78 last_message.content.push_str("\n\n");
79 last_message.content.push_str(&message.content);
80 continue;
81 }
82 }
83
84 new_messages.push(message);
85 }
86 Role::System => {
87 if !system_message.is_empty() {
88 system_message.push_str("\n\n");
89 }
90 system_message.push_str(&message.content);
91 }
92 }
93 }
94
95 if !system_message.is_empty() {
96 new_messages.insert(
97 0,
98 LanguageModelRequestMessage {
99 role: Role::System,
100 content: system_message,
101 },
102 );
103 }
104
105 self.messages = new_messages;
106 }
107}
108
109#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
110pub struct LanguageModelResponseMessage {
111 pub role: Option<Role>,
112 pub content: Option<String>,
113}