google_ai.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use serde::{Deserialize, Serialize};
  4use util::http::HttpClient;
  5
/// Base URL of the Google Generative Language API.
///
/// Pass this as the `api_url` argument of the request functions in this module; they append
/// `/v1beta/models/...` to it.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
  7
  8pub async fn stream_generate_content<T: HttpClient>(
  9    client: &T,
 10    api_url: &str,
 11    api_key: &str,
 12    request: GenerateContentRequest,
 13) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 14    let uri = format!(
 15        "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
 16        api_url, api_key
 17    );
 18
 19    let request = serde_json::to_string(&request)?;
 20    let mut response = client.post_json(&uri, request.into()).await?;
 21    if response.status().is_success() {
 22        let reader = BufReader::new(response.into_body());
 23        Ok(reader
 24            .lines()
 25            .filter_map(|line| async move {
 26                match line {
 27                    Ok(line) => {
 28                        if let Some(line) = line.strip_prefix("data: ") {
 29                            match serde_json::from_str(line) {
 30                                Ok(response) => Some(Ok(response)),
 31                                Err(error) => Some(Err(anyhow!(error))),
 32                            }
 33                        } else {
 34                            None
 35                        }
 36                    }
 37                    Err(error) => Some(Err(anyhow!(error))),
 38                }
 39            })
 40            .boxed())
 41    } else {
 42        let mut text = String::new();
 43        response.body_mut().read_to_string(&mut text).await?;
 44        Err(anyhow!(
 45            "error during streamGenerateContent, status code: {:?}, body: {}",
 46            response.status(),
 47            text
 48        ))
 49    }
 50}
 51
 52pub async fn count_tokens<T: HttpClient>(
 53    client: &T,
 54    api_url: &str,
 55    api_key: &str,
 56    request: CountTokensRequest,
 57) -> Result<CountTokensResponse> {
 58    let uri = format!(
 59        "{}/v1beta/models/gemini-pro:countTokens?key={}",
 60        api_url, api_key
 61    );
 62    let request = serde_json::to_string(&request)?;
 63    let mut response = client.post_json(&uri, request.into()).await?;
 64    let mut text = String::new();
 65    response.body_mut().read_to_string(&mut text).await?;
 66    if response.status().is_success() {
 67        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 68    } else {
 69        Err(anyhow!(
 70            "error during countTokens, status code: {:?}, body: {}",
 71            response.status(),
 72            text
 73        ))
 74    }
 75}
 76
 77#[derive(Debug, Serialize, Deserialize)]
 78pub enum Task {
 79    #[serde(rename = "generateContent")]
 80    GenerateContent,
 81    #[serde(rename = "streamGenerateContent")]
 82    StreamGenerateContent,
 83    #[serde(rename = "countTokens")]
 84    CountTokens,
 85    #[serde(rename = "embedContent")]
 86    EmbedContent,
 87    #[serde(rename = "batchEmbedContents")]
 88    BatchEmbedContents,
 89}
 90
 91#[derive(Debug, Serialize)]
 92#[serde(rename_all = "camelCase")]
 93pub struct GenerateContentRequest {
 94    pub contents: Vec<Content>,
 95    pub generation_config: Option<GenerationConfig>,
 96    pub safety_settings: Option<Vec<SafetySetting>>,
 97}
 98
 99#[derive(Debug, Deserialize)]
100#[serde(rename_all = "camelCase")]
101pub struct GenerateContentResponse {
102    pub candidates: Option<Vec<GenerateContentCandidate>>,
103    pub prompt_feedback: Option<PromptFeedback>,
104}
105
106#[derive(Debug, Deserialize)]
107#[serde(rename_all = "camelCase")]
108pub struct GenerateContentCandidate {
109    pub index: usize,
110    pub content: Content,
111    pub finish_reason: Option<String>,
112    pub finish_message: Option<String>,
113    pub safety_ratings: Option<Vec<SafetyRating>>,
114    pub citation_metadata: Option<CitationMetadata>,
115}
116
117#[derive(Debug, Serialize, Deserialize)]
118#[serde(rename_all = "camelCase")]
119pub struct Content {
120    pub parts: Vec<Part>,
121    pub role: Role,
122}
123
124#[derive(Debug, Deserialize, Serialize)]
125#[serde(rename_all = "camelCase")]
126pub enum Role {
127    User,
128    Model,
129}
130
/// One piece of a [`Content`] message.
///
/// The enum is `untagged`:
/// - Serializing writes the inner struct directly, with no variant wrapper.
/// - Deserializing tries the variants in declaration order: an object is read as a [`TextPart`]
///   if it fits that shape, otherwise as an [`InlineDataPart`].
/// - NOTE(review): a part that matches neither shape makes the enclosing value fail to
///   deserialize.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
}
137
138#[derive(Debug, Serialize, Deserialize)]
139#[serde(rename_all = "camelCase")]
140pub struct TextPart {
141    pub text: String,
142}
143
144#[derive(Debug, Serialize, Deserialize)]
145#[serde(rename_all = "camelCase")]
146pub struct InlineDataPart {
147    pub inline_data: GenerativeContentBlob,
148}
149
150#[derive(Debug, Serialize, Deserialize)]
151#[serde(rename_all = "camelCase")]
152pub struct GenerativeContentBlob {
153    pub mime_type: String,
154    pub data: String,
155}
156
157#[derive(Debug, Deserialize)]
158#[serde(rename_all = "camelCase")]
159pub struct CitationSource {
160    pub start_index: Option<usize>,
161    pub end_index: Option<usize>,
162    pub uri: Option<String>,
163    pub license: Option<String>,
164}
165
166#[derive(Debug, Deserialize)]
167#[serde(rename_all = "camelCase")]
168pub struct CitationMetadata {
169    pub citation_sources: Vec<CitationSource>,
170}
171
172#[derive(Debug, Deserialize)]
173#[serde(rename_all = "camelCase")]
174pub struct PromptFeedback {
175    pub block_reason: Option<String>,
176    pub safety_ratings: Vec<SafetyRating>,
177    pub block_reason_message: Option<String>,
178}
179
180#[derive(Debug, Serialize)]
181#[serde(rename_all = "camelCase")]
182pub struct GenerationConfig {
183    pub candidate_count: Option<usize>,
184    pub stop_sequences: Option<Vec<String>>,
185    pub max_output_tokens: Option<usize>,
186    pub temperature: Option<f64>,
187    pub top_p: Option<f64>,
188    pub top_k: Option<usize>,
189}
190
191#[derive(Debug, Serialize)]
192#[serde(rename_all = "camelCase")]
193pub struct SafetySetting {
194    pub category: HarmCategory,
195    pub threshold: HarmBlockThreshold,
196}
197
/// Harm categories used in [`SafetySetting`] (sent) and [`SafetyRating`] (received).
///
/// Each variant maps to the exact `HARM_CATEGORY_*` string in its `rename` attribute.
/// Deserializing any other string fails.
///
/// NOTE(review): per Google's docs, `Derogatory` through `Dangerous` are for older (PaLM)
/// models and the last four are for Gemini. Confirm before relying on it.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
223
224#[derive(Debug, Serialize)]
225pub enum HarmBlockThreshold {
226    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
227    Unspecified,
228    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
229    BlockLowAndAbove,
230    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
231    BlockMediumAndAbove,
232    #[serde(rename = "BLOCK_ONLY_HIGH")]
233    BlockOnlyHigh,
234    #[serde(rename = "BLOCK_NONE")]
235    BlockNone,
236}
237
238#[derive(Debug, Deserialize)]
239#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
240pub enum HarmProbability {
241    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
242    Unspecified,
243    Negligible,
244    Low,
245    Medium,
246    High,
247}
248
249#[derive(Debug, Deserialize)]
250#[serde(rename_all = "camelCase")]
251pub struct SafetyRating {
252    pub category: HarmCategory,
253    pub probability: HarmProbability,
254}
255
256#[derive(Debug, Serialize)]
257#[serde(rename_all = "camelCase")]
258pub struct CountTokensRequest {
259    pub contents: Vec<Content>,
260}
261
262#[derive(Debug, Deserialize)]
263#[serde(rename_all = "camelCase")]
264pub struct CountTokensResponse {
265    pub total_tokens: usize,
266}