// ai.rs

  1use anyhow::{anyhow, Context as _, Result};
  2use rpc::proto;
  3use util::ResultExt as _;
  4
  5pub fn language_model_request_to_open_ai(
  6    request: proto::CompleteWithLanguageModel,
  7) -> Result<open_ai::Request> {
  8    Ok(open_ai::Request {
  9        model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
 10        messages: request
 11            .messages
 12            .into_iter()
 13            .map(|message: proto::LanguageModelRequestMessage| {
 14                let role = proto::LanguageModelRole::from_i32(message.role)
 15                    .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
 16
 17                let openai_message = match role {
 18                    proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
 19                        content: message.content,
 20                    },
 21                    proto::LanguageModelRole::LanguageModelAssistant => {
 22                        open_ai::RequestMessage::Assistant {
 23                            content: Some(message.content),
 24                            tool_calls: message
 25                                .tool_calls
 26                                .into_iter()
 27                                .filter_map(|call| {
 28                                    Some(open_ai::ToolCall {
 29                                        id: call.id,
 30                                        content: match call.variant? {
 31                                            proto::tool_call::Variant::Function(f) => {
 32                                                open_ai::ToolCallContent::Function {
 33                                                    function: open_ai::FunctionContent {
 34                                                        name: f.name,
 35                                                        arguments: f.arguments,
 36                                                    },
 37                                                }
 38                                            }
 39                                        },
 40                                    })
 41                                })
 42                                .collect(),
 43                        }
 44                    }
 45                    proto::LanguageModelRole::LanguageModelSystem => {
 46                        open_ai::RequestMessage::System {
 47                            content: message.content,
 48                        }
 49                    }
 50                    proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
 51                        tool_call_id: message
 52                            .tool_call_id
 53                            .ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
 54                        content: message.content,
 55                    },
 56                };
 57
 58                Ok(openai_message)
 59            })
 60            .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
 61        stream: true,
 62        stop: request.stop,
 63        temperature: request.temperature,
 64        tools: request
 65            .tools
 66            .into_iter()
 67            .filter_map(|tool| {
 68                Some(match tool.variant? {
 69                    proto::chat_completion_tool::Variant::Function(f) => {
 70                        open_ai::ToolDefinition::Function {
 71                            function: open_ai::FunctionDefinition {
 72                                name: f.name,
 73                                description: f.description,
 74                                parameters: if let Some(params) = &f.parameters {
 75                                    Some(
 76                                        serde_json::from_str(params)
 77                                            .context("failed to deserialize tool parameters")
 78                                            .log_err()?,
 79                                    )
 80                                } else {
 81                                    None
 82                                },
 83                            },
 84                        }
 85                    }
 86                })
 87            })
 88            .collect(),
 89        tool_choice: request.tool_choice,
 90    })
 91}
 92
 93pub fn language_model_request_to_google_ai(
 94    request: proto::CompleteWithLanguageModel,
 95) -> Result<google_ai::GenerateContentRequest> {
 96    Ok(google_ai::GenerateContentRequest {
 97        contents: request
 98            .messages
 99            .into_iter()
100            .map(language_model_request_message_to_google_ai)
101            .collect::<Result<Vec<_>>>()?,
102        generation_config: None,
103        safety_settings: None,
104    })
105}
106
107pub fn language_model_request_message_to_google_ai(
108    message: proto::LanguageModelRequestMessage,
109) -> Result<google_ai::Content> {
110    let role = proto::LanguageModelRole::from_i32(message.role)
111        .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
112
113    Ok(google_ai::Content {
114        parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
115            text: message.content,
116        })],
117        role: match role {
118            proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
119            proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
120            proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
121            proto::LanguageModelRole::LanguageModelTool => {
122                Err(anyhow!("we don't handle tool calls with google ai yet"))?
123            }
124        },
125    })
126}
127
128pub fn count_tokens_request_to_google_ai(
129    request: proto::CountTokensWithLanguageModel,
130) -> Result<google_ai::CountTokensRequest> {
131    Ok(google_ai::CountTokensRequest {
132        contents: request
133            .messages
134            .into_iter()
135            .map(language_model_request_message_to_google_ai)
136            .collect::<Result<Vec<_>>>()?,
137    })
138}