google_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  5use http_client::HttpClient;
  6use serde::{Deserialize, Serialize};
  7
  8pub use supported_countries::*;
  9
/// Base URL of the Google Generative Language REST API.
///
/// The request functions in this module append paths such as
/// `/v1beta/models/{model}:streamGenerateContent` to the `api_url` they are given.
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 11
 12pub async fn stream_generate_content(
 13    client: &dyn HttpClient,
 14    api_url: &str,
 15    api_key: &str,
 16    mut request: GenerateContentRequest,
 17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 18    let uri = format!(
 19        "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
 20        model = request.model
 21    );
 22    request.model.clear();
 23
 24    let request = serde_json::to_string(&request)?;
 25    let mut response = client.post_json(&uri, request.into()).await?;
 26    if response.status().is_success() {
 27        let reader = BufReader::new(response.into_body());
 28        Ok(reader
 29            .lines()
 30            .filter_map(|line| async move {
 31                match line {
 32                    Ok(line) => {
 33                        if let Some(line) = line.strip_prefix("data: ") {
 34                            match serde_json::from_str(line) {
 35                                Ok(response) => Some(Ok(response)),
 36                                Err(error) => Some(Err(anyhow!(error))),
 37                            }
 38                        } else {
 39                            None
 40                        }
 41                    }
 42                    Err(error) => Some(Err(anyhow!(error))),
 43                }
 44            })
 45            .boxed())
 46    } else {
 47        let mut text = String::new();
 48        response.body_mut().read_to_string(&mut text).await?;
 49        Err(anyhow!(
 50            "error during streamGenerateContent, status code: {:?}, body: {}",
 51            response.status(),
 52            text
 53        ))
 54    }
 55}
 56
 57pub async fn count_tokens(
 58    client: &dyn HttpClient,
 59    api_url: &str,
 60    api_key: &str,
 61    request: CountTokensRequest,
 62) -> Result<CountTokensResponse> {
 63    let uri = format!(
 64        "{}/v1beta/models/gemini-pro:countTokens?key={}",
 65        api_url, api_key
 66    );
 67    let request = serde_json::to_string(&request)?;
 68    let mut response = client.post_json(&uri, request.into()).await?;
 69    let mut text = String::new();
 70    response.body_mut().read_to_string(&mut text).await?;
 71    if response.status().is_success() {
 72        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 73    } else {
 74        Err(anyhow!(
 75            "error during countTokens, status code: {:?}, body: {}",
 76            response.status(),
 77            text
 78        ))
 79    }
 80}
 81
 82#[derive(Debug, Serialize, Deserialize)]
 83pub enum Task {
 84    #[serde(rename = "generateContent")]
 85    GenerateContent,
 86    #[serde(rename = "streamGenerateContent")]
 87    StreamGenerateContent,
 88    #[serde(rename = "countTokens")]
 89    CountTokens,
 90    #[serde(rename = "embedContent")]
 91    EmbedContent,
 92    #[serde(rename = "batchEmbedContents")]
 93    BatchEmbedContents,
 94}
 95
/// Request body for `generateContent` / `streamGenerateContent`, serialized
/// with camelCase field names. Field order is the serialized field order.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Model name. [`stream_generate_content`] moves it into the URL path and
    /// clears it; an empty value is omitted from the serialized body and a
    /// missing one deserializes as `""`.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
    /// Content items sent to the model, each with a role and its parts.
    pub contents: Vec<Content>,
    // NOTE(review): with no `skip_serializing_if`, `None` here and below is
    // sent as an explicit JSON `null`; confirm the API treats that as unset.
    pub generation_config: Option<GenerationConfig>,
    pub safety_settings: Option<Vec<SafetySetting>>,
}

/// One decoded response event: in [`stream_generate_content`], each `data: `
/// line of the response is parsed into one of these.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Candidate completions; `None` when the event carries none.
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    /// Feedback about the prompt, when the server sends it.
    pub prompt_feedback: Option<PromptFeedback>,
}
112
113#[derive(Debug, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct GenerateContentCandidate {
116    pub index: usize,
117    pub content: Content,
118    pub finish_reason: Option<String>,
119    pub finish_message: Option<String>,
120    pub safety_ratings: Option<Vec<SafetyRating>>,
121    pub citation_metadata: Option<CitationMetadata>,
122}
123
124#[derive(Debug, Serialize, Deserialize)]
125#[serde(rename_all = "camelCase")]
126pub struct Content {
127    pub parts: Vec<Part>,
128    pub role: Role,
129}
130
131#[derive(Debug, Deserialize, Serialize)]
132#[serde(rename_all = "camelCase")]
133pub enum Role {
134    User,
135    Model,
136}
137
/// One piece of a [`Content`] item.
///
/// `#[serde(untagged)]` means the JSON carries no discriminator. On
/// deserialization serde tries the variants in declaration order and picks the
/// first whose shape matches (`{"text": ...}` before `{"inlineData": ...}`),
/// so the variant order is significant.
// NOTE(review): a part of any other shape fails to match either variant, which
// fails deserialization of the whole enclosing event.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    TextPart(TextPart),
    InlineDataPart(InlineDataPart),
}

/// A plain-text part, serialized as `{"text": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    pub text: String,
}

/// An inline binary part, serialized as `{"inlineData": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    pub inline_data: GenerativeContentBlob,
}

/// Payload of an [`InlineDataPart`], serialized as `{"mimeType": ..., "data": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    pub mime_type: String,
    // NOTE(review): presumably base64-encoded bytes; confirm against the API docs.
    pub data: String,
}
163
/// A source cited by a candidate; every field is optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    // NOTE(review): presumably the span of the candidate's output attributed
    // to this source; confirm the unit (characters vs. bytes) in the API docs.
    pub start_index: Option<usize>,
    pub end_index: Option<usize>,
    pub uri: Option<String>,
    pub license: Option<String>,
}

/// Citation information attached to a [`GenerateContentCandidate`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// Required: an object without `citationSources` fails to deserialize.
    pub citation_sources: Vec<CitationSource>,
}

/// Feedback about the prompt, carried in a [`GenerateContentResponse`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
    /// Presumably why the prompt was blocked, if it was; kept as a raw string.
    pub block_reason: Option<String>,
    /// Required: there is no `#[serde(default)]`, so a missing field fails
    /// deserialization.
    pub safety_ratings: Vec<SafetyRating>,
    pub block_reason_message: Option<String>,
}
186
/// Generation parameters, sent as `generationConfig` in a
/// [`GenerateContentRequest`].
///
/// Every field is optional. `None` is serialized as JSON `null` because no
/// `skip_serializing_if` is set.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    pub candidate_count: Option<usize>,
    pub stop_sequences: Option<Vec<String>>,
    pub max_output_tokens: Option<usize>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub top_k: Option<usize>,
}

/// A blocking threshold for one harm category, sent in
/// [`GenerateContentRequest::safety_settings`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    pub category: HarmCategory,
    pub threshold: HarmBlockThreshold,
}
204
/// Harm categories, serialized under their `HARM_CATEGORY_*` wire names.
///
/// Used both in request [`SafetySetting`]s and in response [`SafetyRating`]s.
// NOTE(review): a category string not listed here fails deserialization of the
// enclosing rating, and therefore of the whole response event.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
230
231#[derive(Debug, Serialize, Deserialize)]
232pub enum HarmBlockThreshold {
233    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
234    Unspecified,
235    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
236    BlockLowAndAbove,
237    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
238    BlockMediumAndAbove,
239    #[serde(rename = "BLOCK_ONLY_HIGH")]
240    BlockOnlyHigh,
241    #[serde(rename = "BLOCK_NONE")]
242    BlockNone,
243}
244
245#[derive(Debug, Serialize, Deserialize)]
246#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
247pub enum HarmProbability {
248    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
249    Unspecified,
250    Negligible,
251    Low,
252    Medium,
253    High,
254}
255
/// Safety assessment for one harm category, found in response candidates and
/// in [`PromptFeedback`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    pub category: HarmCategory,
    pub probability: HarmProbability,
}

/// Request body for [`count_tokens`]: the content whose tokens are counted.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    pub contents: Vec<Content>,
}

/// Response of [`count_tokens`], deserialized from `{"totalTokens": N}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    pub total_tokens: usize,
}
274
/// Gemini models known to this client, plus [`Model::Custom`] for any other model.
///
/// Built-in variants serialize as their model id (e.g. `"gemini-1.5-pro"`).
/// With serde's default external tagging, `Custom` serializes as
/// `{"custom": {"name": ..., "max_tokens": ...}}`. A JSON schema is derived
/// only when the `schemars` feature is enabled.
// NOTE(review): `strum::EnumIter` also yields `Custom` (with default field
// values, per strum's docs); code listing available models may need to skip it.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(rename = "gemini-1.5-pro")]
    Gemini15Pro,
    #[serde(rename = "gemini-1.5-flash")]
    Gemini15Flash,
    /// Any other model, identified by `name`, with a caller-supplied token limit.
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}
285
286impl Model {
287    pub fn id(&self) -> &str {
288        match self {
289            Model::Gemini15Pro => "gemini-1.5-pro",
290            Model::Gemini15Flash => "gemini-1.5-flash",
291            Model::Custom { name, .. } => name,
292        }
293    }
294
295    pub fn display_name(&self) -> &str {
296        match self {
297            Model::Gemini15Pro => "Gemini 1.5 Pro",
298            Model::Gemini15Flash => "Gemini 1.5 Flash",
299            Model::Custom { name, .. } => name,
300        }
301    }
302
303    pub fn max_token_count(&self) -> usize {
304        match self {
305            Model::Gemini15Pro => 2_000_000,
306            Model::Gemini15Flash => 1_000_000,
307            Model::Custom { max_tokens, .. } => *max_tokens,
308        }
309    }
310}
311
312impl std::fmt::Display for Model {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        write!(f, "{}", self.id())
315    }
316}
317
318pub fn extract_text_from_events(
319    events: impl Stream<Item = Result<GenerateContentResponse>>,
320) -> impl Stream<Item = Result<String>> {
321    events.filter_map(|event| async move {
322        match event {
323            Ok(event) => event.candidates.and_then(|candidates| {
324                candidates.into_iter().next().and_then(|candidate| {
325                    candidate.content.parts.into_iter().next().and_then(|part| {
326                        if let Part::TextPart(TextPart { text }) = part {
327                            Some(Ok(text))
328                        } else {
329                            None
330                        }
331                    })
332                })
333            }),
334            Err(error) => Some(Err(error)),
335        }
336    })
337}