google_ai.rs

  1use std::sync::Arc;
  2
  3use anyhow::{anyhow, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  5use http::HttpClient;
  6use serde::{Deserialize, Serialize};
  7
/// Default base URL of the Generative Language REST API; callers pass it (or an
/// override) as the `api_url` argument of the request functions in this module.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
  9
 10pub async fn stream_generate_content(
 11    client: Arc<dyn HttpClient>,
 12    api_url: &str,
 13    api_key: &str,
 14    request: GenerateContentRequest,
 15) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 16    let uri = format!(
 17        "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
 18        api_url, api_key
 19    );
 20
 21    let request = serde_json::to_string(&request)?;
 22    let mut response = client.post_json(&uri, request.into()).await?;
 23    if response.status().is_success() {
 24        let reader = BufReader::new(response.into_body());
 25        Ok(reader
 26            .lines()
 27            .filter_map(|line| async move {
 28                match line {
 29                    Ok(line) => {
 30                        if let Some(line) = line.strip_prefix("data: ") {
 31                            match serde_json::from_str(line) {
 32                                Ok(response) => Some(Ok(response)),
 33                                Err(error) => Some(Err(anyhow!(error))),
 34                            }
 35                        } else {
 36                            None
 37                        }
 38                    }
 39                    Err(error) => Some(Err(anyhow!(error))),
 40                }
 41            })
 42            .boxed())
 43    } else {
 44        let mut text = String::new();
 45        response.body_mut().read_to_string(&mut text).await?;
 46        Err(anyhow!(
 47            "error during streamGenerateContent, status code: {:?}, body: {}",
 48            response.status(),
 49            text
 50        ))
 51    }
 52}
 53
 54pub async fn count_tokens<T: HttpClient>(
 55    client: &T,
 56    api_url: &str,
 57    api_key: &str,
 58    request: CountTokensRequest,
 59) -> Result<CountTokensResponse> {
 60    let uri = format!(
 61        "{}/v1beta/models/gemini-pro:countTokens?key={}",
 62        api_url, api_key
 63    );
 64    let request = serde_json::to_string(&request)?;
 65    let mut response = client.post_json(&uri, request.into()).await?;
 66    let mut text = String::new();
 67    response.body_mut().read_to_string(&mut text).await?;
 68    if response.status().is_success() {
 69        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 70    } else {
 71        Err(anyhow!(
 72            "error during countTokens, status code: {:?}, body: {}",
 73            response.status(),
 74            text
 75        ))
 76    }
 77}
 78
 79#[derive(Debug, Serialize, Deserialize)]
 80pub enum Task {
 81    #[serde(rename = "generateContent")]
 82    GenerateContent,
 83    #[serde(rename = "streamGenerateContent")]
 84    StreamGenerateContent,
 85    #[serde(rename = "countTokens")]
 86    CountTokens,
 87    #[serde(rename = "embedContent")]
 88    EmbedContent,
 89    #[serde(rename = "batchEmbedContents")]
 90    BatchEmbedContents,
 91}
 92
 93#[derive(Debug, Serialize)]
 94#[serde(rename_all = "camelCase")]
 95pub struct GenerateContentRequest {
 96    pub contents: Vec<Content>,
 97    pub generation_config: Option<GenerationConfig>,
 98    pub safety_settings: Option<Vec<SafetySetting>>,
 99}
100
101#[derive(Debug, Deserialize)]
102#[serde(rename_all = "camelCase")]
103pub struct GenerateContentResponse {
104    pub candidates: Option<Vec<GenerateContentCandidate>>,
105    pub prompt_feedback: Option<PromptFeedback>,
106}
107
108#[derive(Debug, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct GenerateContentCandidate {
111    pub index: usize,
112    pub content: Content,
113    pub finish_reason: Option<String>,
114    pub finish_message: Option<String>,
115    pub safety_ratings: Option<Vec<SafetyRating>>,
116    pub citation_metadata: Option<CitationMetadata>,
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120#[serde(rename_all = "camelCase")]
121pub struct Content {
122    pub parts: Vec<Part>,
123    pub role: Role,
124}
125
126#[derive(Debug, Deserialize, Serialize)]
127#[serde(rename_all = "camelCase")]
128pub enum Role {
129    User,
130    Model,
131}
132
133#[derive(Debug, Serialize, Deserialize)]
134#[serde(untagged)]
135pub enum Part {
136    TextPart(TextPart),
137    InlineDataPart(InlineDataPart),
138}
139
140#[derive(Debug, Serialize, Deserialize)]
141#[serde(rename_all = "camelCase")]
142pub struct TextPart {
143    pub text: String,
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147#[serde(rename_all = "camelCase")]
148pub struct InlineDataPart {
149    pub inline_data: GenerativeContentBlob,
150}
151
152#[derive(Debug, Serialize, Deserialize)]
153#[serde(rename_all = "camelCase")]
154pub struct GenerativeContentBlob {
155    pub mime_type: String,
156    pub data: String,
157}
158
159#[derive(Debug, Deserialize)]
160#[serde(rename_all = "camelCase")]
161pub struct CitationSource {
162    pub start_index: Option<usize>,
163    pub end_index: Option<usize>,
164    pub uri: Option<String>,
165    pub license: Option<String>,
166}
167
168#[derive(Debug, Deserialize)]
169#[serde(rename_all = "camelCase")]
170pub struct CitationMetadata {
171    pub citation_sources: Vec<CitationSource>,
172}
173
174#[derive(Debug, Deserialize)]
175#[serde(rename_all = "camelCase")]
176pub struct PromptFeedback {
177    pub block_reason: Option<String>,
178    pub safety_ratings: Vec<SafetyRating>,
179    pub block_reason_message: Option<String>,
180}
181
182#[derive(Debug, Serialize)]
183#[serde(rename_all = "camelCase")]
184pub struct GenerationConfig {
185    pub candidate_count: Option<usize>,
186    pub stop_sequences: Option<Vec<String>>,
187    pub max_output_tokens: Option<usize>,
188    pub temperature: Option<f64>,
189    pub top_p: Option<f64>,
190    pub top_k: Option<usize>,
191}
192
193#[derive(Debug, Serialize)]
194#[serde(rename_all = "camelCase")]
195pub struct SafetySetting {
196    pub category: HarmCategory,
197    pub threshold: HarmBlockThreshold,
198}
199
200#[derive(Debug, Serialize, Deserialize)]
201pub enum HarmCategory {
202    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
203    Unspecified,
204    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
205    Derogatory,
206    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
207    Toxicity,
208    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
209    Violence,
210    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
211    Sexual,
212    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
213    Medical,
214    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
215    Dangerous,
216    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
217    Harassment,
218    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
219    HateSpeech,
220    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
221    SexuallyExplicit,
222    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
223    DangerousContent,
224}
225
226#[derive(Debug, Serialize)]
227pub enum HarmBlockThreshold {
228    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
229    Unspecified,
230    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
231    BlockLowAndAbove,
232    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
233    BlockMediumAndAbove,
234    #[serde(rename = "BLOCK_ONLY_HIGH")]
235    BlockOnlyHigh,
236    #[serde(rename = "BLOCK_NONE")]
237    BlockNone,
238}
239
240#[derive(Debug, Deserialize)]
241#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
242pub enum HarmProbability {
243    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
244    Unspecified,
245    Negligible,
246    Low,
247    Medium,
248    High,
249}
250
251#[derive(Debug, Deserialize)]
252#[serde(rename_all = "camelCase")]
253pub struct SafetyRating {
254    pub category: HarmCategory,
255    pub probability: HarmProbability,
256}
257
258#[derive(Debug, Serialize)]
259#[serde(rename_all = "camelCase")]
260pub struct CountTokensRequest {
261    pub contents: Vec<Content>,
262}
263
264#[derive(Debug, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct CountTokensResponse {
267    pub total_tokens: usize,
268}