1use crate::{
2 model::{CloudModel, LanguageModel},
3 role::Role,
4};
5use serde::{Deserialize, Serialize};
6
7#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
8pub struct LanguageModelRequestMessage {
9 pub role: Role,
10 pub content: String,
11}
12
13impl LanguageModelRequestMessage {
14 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
15 proto::LanguageModelRequestMessage {
16 role: self.role.to_proto() as i32,
17 content: self.content.clone(),
18 tool_calls: Vec::new(),
19 tool_call_id: None,
20 }
21 }
22}
23
24#[derive(Debug, Default, Serialize, Deserialize)]
25pub struct LanguageModelRequest {
26 pub model: LanguageModel,
27 pub messages: Vec<LanguageModelRequestMessage>,
28 pub stop: Vec<String>,
29 pub temperature: f32,
30}
31
32impl LanguageModelRequest {
33 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
34 proto::CompleteWithLanguageModel {
35 model: self.model.id().to_string(),
36 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
37 stop: self.stop.clone(),
38 temperature: self.temperature,
39 tool_choice: None,
40 tools: Vec::new(),
41 }
42 }
43
44 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
45 pub fn preprocess(&mut self) {
46 match &self.model {
47 LanguageModel::OpenAi(_) => {}
48 LanguageModel::Anthropic(_) => {}
49 LanguageModel::Ollama(_) => {}
50 LanguageModel::Cloud(model) => match model {
51 CloudModel::Claude3Opus
52 | CloudModel::Claude3Sonnet
53 | CloudModel::Claude3Haiku
54 | CloudModel::Claude3_5Sonnet => {
55 self.preprocess_anthropic();
56 }
57 _ => {}
58 },
59 }
60 }
61
62 pub fn preprocess_anthropic(&mut self) {
63 let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
64 let mut system_message = String::new();
65
66 for message in self.messages.drain(..) {
67 if message.content.is_empty() {
68 continue;
69 }
70
71 match message.role {
72 Role::User | Role::Assistant => {
73 if let Some(last_message) = new_messages.last_mut() {
74 if last_message.role == message.role {
75 last_message.content.push_str("\n\n");
76 last_message.content.push_str(&message.content);
77 continue;
78 }
79 }
80
81 new_messages.push(message);
82 }
83 Role::System => {
84 if !system_message.is_empty() {
85 system_message.push_str("\n\n");
86 }
87 system_message.push_str(&message.content);
88 }
89 }
90 }
91
92 if !system_message.is_empty() {
93 new_messages.insert(
94 0,
95 LanguageModelRequestMessage {
96 role: Role::System,
97 content: system_message,
98 },
99 );
100 }
101
102 self.messages = new_messages;
103 }
104}
105
106#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
107pub struct LanguageModelResponseMessage {
108 pub role: Option<Role>,
109 pub content: Option<String>,
110}