use anyhow::{anyhow, Context as _, Result};
use rpc::proto;
use util::ResultExt as _;

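/// Converts a protobuf `CompleteWithLanguageModel` request into an OpenAI
/// completion request. Unknown model IDs fall back to `Model::FourTurbo`, and
/// the completion is always requested as a stream.
///
/// A minimal sketch of a call site (not compiled; the proto structs may carry
/// additional fields beyond the ones used here):
///
/// ```ignore
/// let request = proto::CompleteWithLanguageModel {
///     model: "gpt-4-turbo-preview".into(),
///     messages: vec![proto::LanguageModelRequestMessage {
///         role: proto::LanguageModelRole::LanguageModelUser as i32,
///         content: "Hello!".into(),
///         tool_calls: Vec::new(),
///         tool_call_id: None,
///     }],
///     stop: Vec::new(),
///     temperature: 1.0,
///     tools: Vec::new(),
///     tool_choice: None,
///     ..Default::default()
/// };
/// let open_ai_request = language_model_request_to_open_ai(request)?;
/// ```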
pub fn language_model_request_to_open_ai(
    request: proto::CompleteWithLanguageModel,
) -> Result<open_ai::Request> {
    Ok(open_ai::Request {
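        // Fall back to `Model::FourTurbo` when the requested model ID is not recognized.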
        model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
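        // Convert each proto message into the role-specific OpenAI message
        // variant; an unrecognized role fails the whole conversion.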
        messages: request
            .messages
            .into_iter()
            .map(|message: proto::LanguageModelRequestMessage| {
                let role = proto::LanguageModelRole::from_i32(message.role)
                    .ok_or_else(|| anyhow!("invalid role {}", message.role))?;

                let openai_message = match role {
                    proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
                        content: message.content,
                    },
                    proto::LanguageModelRole::LanguageModelAssistant => {
                        open_ai::RequestMessage::Assistant {
                            content: Some(message.content),
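                            // `filter_map` silently drops any tool call whose
                            // variant is unset rather than failing the request.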
                            tool_calls: message
                                .tool_calls
                                .into_iter()
                                .filter_map(|call| {
                                    Some(open_ai::ToolCall {
                                        id: call.id,
                                        content: match call.variant? {
                                            proto::tool_call::Variant::Function(f) => {
                                                open_ai::ToolCallContent::Function {
                                                    function: open_ai::FunctionContent {
                                                        name: f.name,
                                                        arguments: f.arguments,
                                                    },
                                                }
                                            }
                                        },
                                    })
                                })
                                .collect(),
                        }
                    }
                    proto::LanguageModelRole::LanguageModelSystem => {
                        open_ai::RequestMessage::System {
                            content: message.content,
                        }
                    }
                    proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
                        tool_call_id: message
                            .tool_call_id
                            .ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
                        content: message.content,
                    },
                };

                Ok(openai_message)
            })
            .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
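        // Always stream the completion back to the client.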
        stream: true,
        stop: request.stop,
        temperature: request.temperature,
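        // Tool definitions with an unset variant are skipped; definitions whose
        // JSON parameters fail to deserialize are logged via `log_err` and dropped.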
        tools: request
            .tools
            .into_iter()
            .filter_map(|tool| {
                Some(match tool.variant? {
                    proto::chat_completion_tool::Variant::Function(f) => {
                        open_ai::ToolDefinition::Function {
                            function: open_ai::FunctionDefinition {
                                name: f.name,
                                description: f.description,
                                parameters: if let Some(params) = &f.parameters {
                                    Some(
                                        serde_json::from_str(params)
                                            .context("failed to deserialize tool parameters")
                                            .log_err()?,
                                    )
                                } else {
                                    None
                                },
                            },
                        }
                    }
                })
            })
            .collect(),
        tool_choice: request.tool_choice,
    })
}

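/// Converts a protobuf completion request into a Google AI
/// `GenerateContentRequest`. Generation config and safety settings are left
/// unset.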
pub fn language_model_request_to_google_ai(
    request: proto::CompleteWithLanguageModel,
) -> Result<google_ai::GenerateContentRequest> {
    Ok(google_ai::GenerateContentRequest {
        contents: request
            .messages
            .into_iter()
            .map(language_model_request_message_to_google_ai)
            .collect::<Result<Vec<_>>>()?,
        generation_config: None,
        safety_settings: None,
    })
}

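/// Converts a single protobuf message into Google AI `Content`. System
/// messages are mapped to the user role, and tool messages return an error,
/// since tool calls aren't handled for Google AI yet.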
pub fn language_model_request_message_to_google_ai(
    message: proto::LanguageModelRequestMessage,
) -> Result<google_ai::Content> {
    let role = proto::LanguageModelRole::from_i32(message.role)
        .ok_or_else(|| anyhow!("invalid role {}", message.role))?;

    Ok(google_ai::Content {
        parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
            text: message.content,
        })],
        role: match role {
            proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
            proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
            proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
            proto::LanguageModelRole::LanguageModelTool => {
                Err(anyhow!("we don't handle tool calls with google ai yet"))?
            }
        },
    })
}

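/// Converts a protobuf token-counting request into a Google AI
/// `CountTokensRequest`, reusing the same per-message conversion as the
/// completion path.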
pub fn count_tokens_request_to_google_ai(
    request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {
    Ok(google_ai::CountTokensRequest {
        contents: request
            .messages
            .into_iter()
            .map(language_model_request_message_to_google_ai)
            .collect::<Result<Vec<_>>>()?,
    })
}