google_ai.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  3use http_client::HttpClient;
  4use serde::{Deserialize, Serialize};
  5
  6pub const API_URL: &str = "https://generativelanguage.googleapis.com";
  7
  8pub async fn stream_generate_content(
  9    client: &dyn HttpClient,
 10    api_url: &str,
 11    api_key: &str,
 12    mut request: GenerateContentRequest,
 13) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 14    let uri = format!(
 15        "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
 16        model = request.model
 17    );
 18    request.model.clear();
 19
 20    let request = serde_json::to_string(&request)?;
 21    let mut response = client.post_json(&uri, request.into()).await?;
 22    if response.status().is_success() {
 23        let reader = BufReader::new(response.into_body());
 24        Ok(reader
 25            .lines()
 26            .filter_map(|line| async move {
 27                match line {
 28                    Ok(line) => {
 29                        if let Some(line) = line.strip_prefix("data: ") {
 30                            match serde_json::from_str(line) {
 31                                Ok(response) => Some(Ok(response)),
 32                                Err(error) => Some(Err(anyhow!(error))),
 33                            }
 34                        } else {
 35                            None
 36                        }
 37                    }
 38                    Err(error) => Some(Err(anyhow!(error))),
 39                }
 40            })
 41            .boxed())
 42    } else {
 43        let mut text = String::new();
 44        response.body_mut().read_to_string(&mut text).await?;
 45        Err(anyhow!(
 46            "error during streamGenerateContent, status code: {:?}, body: {}",
 47            response.status(),
 48            text
 49        ))
 50    }
 51}
 52
 53pub async fn count_tokens(
 54    client: &dyn HttpClient,
 55    api_url: &str,
 56    api_key: &str,
 57    request: CountTokensRequest,
 58) -> Result<CountTokensResponse> {
 59    let uri = format!(
 60        "{}/v1beta/models/gemini-pro:countTokens?key={}",
 61        api_url, api_key
 62    );
 63    let request = serde_json::to_string(&request)?;
 64    let mut response = client.post_json(&uri, request.into()).await?;
 65    let mut text = String::new();
 66    response.body_mut().read_to_string(&mut text).await?;
 67    if response.status().is_success() {
 68        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 69    } else {
 70        Err(anyhow!(
 71            "error during countTokens, status code: {:?}, body: {}",
 72            response.status(),
 73            text
 74        ))
 75    }
 76}
 77
 78#[derive(Debug, Serialize, Deserialize)]
 79pub enum Task {
 80    #[serde(rename = "generateContent")]
 81    GenerateContent,
 82    #[serde(rename = "streamGenerateContent")]
 83    StreamGenerateContent,
 84    #[serde(rename = "countTokens")]
 85    CountTokens,
 86    #[serde(rename = "embedContent")]
 87    EmbedContent,
 88    #[serde(rename = "batchEmbedContents")]
 89    BatchEmbedContents,
 90}
 91
/// Request body for content generation, serialized with camelCase keys.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
    /// Model name. Left out of the serialized JSON when empty and defaulted to
    /// an empty string when missing on deserialization;
    /// `stream_generate_content` moves it into the URL path and clears it.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub model: String,
    /// The content entries sent to the model.
    pub contents: Vec<Content>,
    /// Optional generation parameters (`generationConfig`).
    pub generation_config: Option<GenerationConfig>,
    /// Optional safety settings (`safetySettings`).
    pub safety_settings: Option<Vec<SafetySetting>>,
}
101
/// One parsed response (or streamed event) from content generation.
/// Both fields are optional and use camelCase keys on the wire.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
    /// Generated candidates, if any were included.
    pub candidates: Option<Vec<GenerateContentCandidate>>,
    /// Feedback about the prompt (`promptFeedback`), if included.
    pub prompt_feedback: Option<PromptFeedback>,
}
108
109#[derive(Debug, Serialize, Deserialize)]
110#[serde(rename_all = "camelCase")]
111pub struct GenerateContentCandidate {
112    pub index: usize,
113    pub content: Content,
114    pub finish_reason: Option<String>,
115    pub finish_message: Option<String>,
116    pub safety_ratings: Option<Vec<SafetyRating>>,
117    pub citation_metadata: Option<CitationMetadata>,
118}
119
120#[derive(Debug, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct Content {
123    pub parts: Vec<Part>,
124    pub role: Role,
125}
126
127#[derive(Debug, Deserialize, Serialize)]
128#[serde(rename_all = "camelCase")]
129pub enum Role {
130    User,
131    Model,
132}
133
/// One part of a [`Content`] entry.
///
/// The enum is untagged: no variant name appears on the wire, and
/// deserialization tries the variants in declaration order, keeping the first
/// one whose shape matches.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    /// A part carrying a `text` field.
    TextPart(TextPart),
    /// A part carrying an `inlineData` field.
    InlineDataPart(InlineDataPart),
}
140
/// A text part, (de)serialized as `{"text": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
    /// The text content.
    pub text: String,
}
146
/// A part holding inline data, (de)serialized as `{"inlineData": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
    /// The embedded blob (`inlineData` on the wire).
    pub inline_data: GenerativeContentBlob,
}
152
/// A typed data blob carried by an [`InlineDataPart`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
    /// Type of the payload (`mimeType` on the wire).
    pub mime_type: String,
    /// The payload as a string.
    // NOTE(review): presumably base64-encoded bytes — confirm against the API docs.
    pub data: String,
}
159
/// A single citation entry; every field is optional.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
    /// `startIndex` on the wire.
    // NOTE(review): presumably an offset into the generated content — confirm units.
    pub start_index: Option<usize>,
    /// `endIndex` on the wire.
    pub end_index: Option<usize>,
    /// URI of the cited source, if provided.
    pub uri: Option<String>,
    /// License of the cited source, if provided.
    pub license: Option<String>,
}
168
/// Citation information attached to a [`GenerateContentCandidate`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
    /// The individual citations (`citationSources` on the wire).
    pub citation_sources: Vec<CitationSource>,
}
174
175#[derive(Debug, Serialize, Deserialize)]
176#[serde(rename_all = "camelCase")]
177pub struct PromptFeedback {
178    pub block_reason: Option<String>,
179    pub safety_ratings: Vec<SafetyRating>,
180    pub block_reason_message: Option<String>,
181}
182
/// Optional generation parameters sent with a [`GenerateContentRequest`].
/// Every field is optional; keys are camelCase on the wire.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
    /// `candidateCount` on the wire.
    pub candidate_count: Option<usize>,
    /// `stopSequences` on the wire.
    pub stop_sequences: Option<Vec<String>>,
    /// `maxOutputTokens` on the wire.
    pub max_output_tokens: Option<usize>,
    /// `temperature` on the wire.
    pub temperature: Option<f64>,
    /// `topP` on the wire.
    pub top_p: Option<f64>,
    /// `topK` on the wire.
    pub top_k: Option<usize>,
}
193
/// Pairs a harm category with the blocking threshold to apply to it.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
    /// The category this setting applies to.
    pub category: HarmCategory,
    /// The blocking threshold for that category.
    pub threshold: HarmBlockThreshold,
}
200
/// Harm categories used in safety settings and safety ratings.
///
/// Each variant is (de)serialized as its `HARM_CATEGORY_*` wire name given in
/// the `rename` attribute.
#[derive(Debug, Serialize, Deserialize)]
pub enum HarmCategory {
    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
    Unspecified,
    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
    Derogatory,
    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
    Toxicity,
    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
    Violence,
    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
    Sexual,
    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
    Medical,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
    Dangerous,
    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
    Harassment,
    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
    HateSpeech,
    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
    SexuallyExplicit,
    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
    DangerousContent,
}
226
227#[derive(Debug, Serialize, Deserialize)]
228pub enum HarmBlockThreshold {
229    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
230    Unspecified,
231    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
232    BlockLowAndAbove,
233    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
234    BlockMediumAndAbove,
235    #[serde(rename = "BLOCK_ONLY_HIGH")]
236    BlockOnlyHigh,
237    #[serde(rename = "BLOCK_NONE")]
238    BlockNone,
239}
240
/// Probability level reported in a [`SafetyRating`].
///
/// Variants are (de)serialized in SCREAMING_SNAKE_CASE (`"NEGLIGIBLE"`,
/// `"LOW"`, `"MEDIUM"`, `"HIGH"`), except `Unspecified`, which uses the
/// explicit name below.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
    Unspecified,
    Negligible,
    Low,
    Medium,
    High,
}
251
/// A harm category together with its reported probability level.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
    /// The category being rated.
    pub category: HarmCategory,
    /// The probability level reported for that category.
    pub probability: HarmProbability,
}
258
/// Request body for [`count_tokens`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
    /// The content entries whose tokens should be counted.
    pub contents: Vec<Content>,
}
264
/// Response body returned by [`count_tokens`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
    /// The token count (`totalTokens` on the wire).
    pub total_tokens: usize,
}
270
/// The models known to this module, plus a user-defined escape hatch.
///
/// The built-in variants are (de)serialized as their model id; `Custom` is
/// (de)serialized under the name `"custom"` and carries its own name and
/// token limit.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
    #[serde(rename = "gemini-1.5-pro")]
    Gemini15Pro,
    #[serde(rename = "gemini-1.5-flash")]
    Gemini15Flash,
    /// A model not listed above; `name` is used as both its id and its display
    /// name, and `max_tokens` as its maximum token count (see `impl Model`).
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}
281
282impl Model {
283    pub fn id(&self) -> &str {
284        match self {
285            Model::Gemini15Pro => "gemini-1.5-pro",
286            Model::Gemini15Flash => "gemini-1.5-flash",
287            Model::Custom { name, .. } => name,
288        }
289    }
290
291    pub fn display_name(&self) -> &str {
292        match self {
293            Model::Gemini15Pro => "Gemini 1.5 Pro",
294            Model::Gemini15Flash => "Gemini 1.5 Flash",
295            Model::Custom { name, .. } => name,
296        }
297    }
298
299    pub fn max_token_count(&self) -> usize {
300        match self {
301            Model::Gemini15Pro => 2_000_000,
302            Model::Gemini15Flash => 1_000_000,
303            Model::Custom { max_tokens, .. } => *max_tokens,
304        }
305    }
306}
307
308impl std::fmt::Display for Model {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        write!(f, "{}", self.id())
311    }
312}
313
314pub fn extract_text_from_events(
315    events: impl Stream<Item = Result<GenerateContentResponse>>,
316) -> impl Stream<Item = Result<String>> {
317    events.filter_map(|event| async move {
318        match event {
319            Ok(event) => event.candidates.and_then(|candidates| {
320                candidates.into_iter().next().and_then(|candidate| {
321                    candidate.content.parts.into_iter().next().and_then(|part| {
322                        if let Part::TextPart(TextPart { text }) = part {
323                            Some(Ok(text))
324                        } else {
325                            None
326                        }
327                    })
328                })
329            }),
330            Err(error) => Some(Err(error)),
331        }
332    })
333}