1use anyhow::{anyhow, Result};
2use rpc::proto;
3
4pub fn language_model_request_to_open_ai(
5 request: proto::CompleteWithLanguageModel,
6) -> Result<open_ai::Request> {
7 Ok(open_ai::Request {
8 model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
9 messages: request
10 .messages
11 .into_iter()
12 .map(|message| {
13 let role = proto::LanguageModelRole::from_i32(message.role)
14 .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
15 Ok(open_ai::RequestMessage {
16 role: match role {
17 proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
18 proto::LanguageModelRole::LanguageModelAssistant => {
19 open_ai::Role::Assistant
20 }
21 proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
22 },
23 content: message.content,
24 })
25 })
26 .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
27 stream: true,
28 stop: request.stop,
29 temperature: request.temperature,
30 })
31}
32
33pub fn language_model_request_to_google_ai(
34 request: proto::CompleteWithLanguageModel,
35) -> Result<google_ai::GenerateContentRequest> {
36 Ok(google_ai::GenerateContentRequest {
37 contents: request
38 .messages
39 .into_iter()
40 .map(language_model_request_message_to_google_ai)
41 .collect::<Result<Vec<_>>>()?,
42 generation_config: None,
43 safety_settings: None,
44 })
45}
46
47pub fn language_model_request_message_to_google_ai(
48 message: proto::LanguageModelRequestMessage,
49) -> Result<google_ai::Content> {
50 let role = proto::LanguageModelRole::from_i32(message.role)
51 .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
52
53 Ok(google_ai::Content {
54 parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
55 text: message.content,
56 })],
57 role: match role {
58 proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
59 proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
60 proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
61 },
62 })
63}
64
65pub fn count_tokens_request_to_google_ai(
66 request: proto::CountTokensWithLanguageModel,
67) -> Result<google_ai::CountTokensRequest> {
68 Ok(google_ai::CountTokensRequest {
69 contents: request
70 .messages
71 .into_iter()
72 .map(language_model_request_message_to_google_ai)
73 .collect::<Result<Vec<_>>>()?,
74 })
75}