// ai.rs

 1use anyhow::{anyhow, Result};
 2use rpc::proto;
 3
 4pub fn language_model_request_to_open_ai(
 5    request: proto::CompleteWithLanguageModel,
 6) -> Result<open_ai::Request> {
 7    Ok(open_ai::Request {
 8        model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
 9        messages: request
10            .messages
11            .into_iter()
12            .map(|message| {
13                let role = proto::LanguageModelRole::from_i32(message.role)
14                    .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
15                Ok(open_ai::RequestMessage {
16                    role: match role {
17                        proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
18                        proto::LanguageModelRole::LanguageModelAssistant => {
19                            open_ai::Role::Assistant
20                        }
21                        proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
22                    },
23                    content: message.content,
24                })
25            })
26            .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
27        stream: true,
28        stop: request.stop,
29        temperature: request.temperature,
30    })
31}
32
33pub fn language_model_request_to_google_ai(
34    request: proto::CompleteWithLanguageModel,
35) -> Result<google_ai::GenerateContentRequest> {
36    Ok(google_ai::GenerateContentRequest {
37        contents: request
38            .messages
39            .into_iter()
40            .map(language_model_request_message_to_google_ai)
41            .collect::<Result<Vec<_>>>()?,
42        generation_config: None,
43        safety_settings: None,
44    })
45}
46
47pub fn language_model_request_message_to_google_ai(
48    message: proto::LanguageModelRequestMessage,
49) -> Result<google_ai::Content> {
50    let role = proto::LanguageModelRole::from_i32(message.role)
51        .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
52
53    Ok(google_ai::Content {
54        parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
55            text: message.content,
56        })],
57        role: match role {
58            proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
59            proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
60            proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
61        },
62    })
63}
64
65pub fn count_tokens_request_to_google_ai(
66    request: proto::CountTokensWithLanguageModel,
67) -> Result<google_ai::CountTokensRequest> {
68    Ok(google_ai::CountTokensRequest {
69        contents: request
70            .messages
71            .into_iter()
72            .map(language_model_request_message_to_google_ai)
73            .collect::<Result<Vec<_>>>()?,
74    })
75}