google_ai.rs

  1use std::sync::Arc;
  2
  3use anyhow::{anyhow, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  5use http_client::HttpClient;
  6use serde::{Deserialize, Serialize};
  7
/// Base URL of Google's Generative Language API.
///
/// Presumably the value callers pass as `api_url` to the request functions in
/// this file, which append `/v1beta/models/...` to it — confirm against callers.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
  9
 10pub async fn stream_generate_content(
 11    client: Arc<dyn HttpClient>,
 12    api_url: &str,
 13    api_key: &str,
 14    model: &str,
 15    request: GenerateContentRequest,
 16) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 17    let uri = format!(
 18        "{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}",
 19        api_url, api_key
 20    );
 21
 22    let request = serde_json::to_string(&request)?;
 23    let mut response = client.post_json(&uri, request.into()).await?;
 24    if response.status().is_success() {
 25        let reader = BufReader::new(response.into_body());
 26        Ok(reader
 27            .lines()
 28            .filter_map(|line| async move {
 29                match line {
 30                    Ok(line) => {
 31                        if let Some(line) = line.strip_prefix("data: ") {
 32                            match serde_json::from_str(line) {
 33                                Ok(response) => Some(Ok(response)),
 34                                Err(error) => Some(Err(anyhow!(error))),
 35                            }
 36                        } else {
 37                            None
 38                        }
 39                    }
 40                    Err(error) => Some(Err(anyhow!(error))),
 41                }
 42            })
 43            .boxed())
 44    } else {
 45        let mut text = String::new();
 46        response.body_mut().read_to_string(&mut text).await?;
 47        Err(anyhow!(
 48            "error during streamGenerateContent, status code: {:?}, body: {}",
 49            response.status(),
 50            text
 51        ))
 52    }
 53}
 54
 55pub async fn count_tokens<T: HttpClient>(
 56    client: &T,
 57    api_url: &str,
 58    api_key: &str,
 59    request: CountTokensRequest,
 60) -> Result<CountTokensResponse> {
 61    let uri = format!(
 62        "{}/v1beta/models/gemini-pro:countTokens?key={}",
 63        api_url, api_key
 64    );
 65    let request = serde_json::to_string(&request)?;
 66    let mut response = client.post_json(&uri, request.into()).await?;
 67    let mut text = String::new();
 68    response.body_mut().read_to_string(&mut text).await?;
 69    if response.status().is_success() {
 70        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 71    } else {
 72        Err(anyhow!(
 73            "error during countTokens, status code: {:?}, body: {}",
 74            response.status(),
 75            text
 76        ))
 77    }
 78}
 79
 80#[derive(Debug, Serialize, Deserialize)]
 81pub enum Task {
 82    #[serde(rename = "generateContent")]
 83    GenerateContent,
 84    #[serde(rename = "streamGenerateContent")]
 85    StreamGenerateContent,
 86    #[serde(rename = "countTokens")]
 87    CountTokens,
 88    #[serde(rename = "embedContent")]
 89    EmbedContent,
 90    #[serde(rename = "batchEmbedContents")]
 91    BatchEmbedContents,
 92}
 93
 94#[derive(Debug, Serialize)]
 95#[serde(rename_all = "camelCase")]
 96pub struct GenerateContentRequest {
 97    pub contents: Vec<Content>,
 98    pub generation_config: Option<GenerationConfig>,
 99    pub safety_settings: Option<Vec<SafetySetting>>,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub struct GenerateContentResponse {
105    pub candidates: Option<Vec<GenerateContentCandidate>>,
106    pub prompt_feedback: Option<PromptFeedback>,
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(rename_all = "camelCase")]
111pub struct GenerateContentCandidate {
112    pub index: usize,
113    pub content: Content,
114    pub finish_reason: Option<String>,
115    pub finish_message: Option<String>,
116    pub safety_ratings: Option<Vec<SafetyRating>>,
117    pub citation_metadata: Option<CitationMetadata>,
118}
119
120#[derive(Debug, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct Content {
123    pub parts: Vec<Part>,
124    pub role: Role,
125}
126
127#[derive(Debug, Deserialize, Serialize)]
128#[serde(rename_all = "camelCase")]
129pub enum Role {
130    User,
131    Model,
132}
133
/// One element of [`Content::parts`].
///
/// The enum is untagged: no variant name appears in the JSON. On
/// deserialization serde tries the variants in declaration order and keeps the
/// first one whose fields match, so the order of the variants is significant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    /// An object with a `text` key.
    TextPart(TextPart),
    /// An object with an `inlineData` key.
    InlineDataPart(InlineDataPart),
}
140
141#[derive(Debug, Serialize, Deserialize)]
142#[serde(rename_all = "camelCase")]
143pub struct TextPart {
144    pub text: String,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct InlineDataPart {
150    pub inline_data: GenerativeContentBlob,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154#[serde(rename_all = "camelCase")]
155pub struct GenerativeContentBlob {
156    pub mime_type: String,
157    pub data: String,
158}
159
160#[derive(Debug, Deserialize)]
161#[serde(rename_all = "camelCase")]
162pub struct CitationSource {
163    pub start_index: Option<usize>,
164    pub end_index: Option<usize>,
165    pub uri: Option<String>,
166    pub license: Option<String>,
167}
168
169#[derive(Debug, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct CitationMetadata {
172    pub citation_sources: Vec<CitationSource>,
173}
174
175#[derive(Debug, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct PromptFeedback {
178    pub block_reason: Option<String>,
179    pub safety_ratings: Vec<SafetyRating>,
180    pub block_reason_message: Option<String>,
181}
182
183#[derive(Debug, Serialize)]
184#[serde(rename_all = "camelCase")]
185pub struct GenerationConfig {
186    pub candidate_count: Option<usize>,
187    pub stop_sequences: Option<Vec<String>>,
188    pub max_output_tokens: Option<usize>,
189    pub temperature: Option<f64>,
190    pub top_p: Option<f64>,
191    pub top_k: Option<usize>,
192}
193
194#[derive(Debug, Serialize)]
195#[serde(rename_all = "camelCase")]
196pub struct SafetySetting {
197    pub category: HarmCategory,
198    pub threshold: HarmBlockThreshold,
199}
200
/// Harm category used by [`SafetySetting`] (serialized) and [`SafetyRating`]
/// (deserialized).
///
/// Each variant maps to the `HARM_CATEGORY_*` string given in its `rename`
/// attribute. A string that matches none of them fails to deserialize.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
226
227#[derive(Debug, Serialize)]
228pub enum HarmBlockThreshold {
229    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
230    Unspecified,
231    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
232    BlockLowAndAbove,
233    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
234    BlockMediumAndAbove,
235    #[serde(rename = "BLOCK_ONLY_HIGH")]
236    BlockOnlyHigh,
237    #[serde(rename = "BLOCK_NONE")]
238    BlockNone,
239}
240
241#[derive(Debug, Deserialize)]
242#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
243pub enum HarmProbability {
244    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
245    Unspecified,
246    Negligible,
247    Low,
248    Medium,
249    High,
250}
251
252#[derive(Debug, Deserialize)]
253#[serde(rename_all = "camelCase")]
254pub struct SafetyRating {
255    pub category: HarmCategory,
256    pub probability: HarmProbability,
257}
258
259#[derive(Debug, Serialize)]
260#[serde(rename_all = "camelCase")]
261pub struct CountTokensRequest {
262    pub contents: Vec<Content>,
263}
264
265#[derive(Debug, Deserialize)]
266#[serde(rename_all = "camelCase")]
267pub struct CountTokensResponse {
268    pub total_tokens: usize,
269}