google_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use serde::{Deserialize, Serialize};
  7
  8pub use supported_countries::*;
  9
 10pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 11
 12pub async fn stream_generate_content(
 13    client: &dyn HttpClient,
 14    api_url: &str,
 15    api_key: &str,
 16    mut request: GenerateContentRequest,
 17) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
 18    let uri = format!(
 19        "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
 20        model = request.model
 21    );
 22    request.model.clear();
 23
 24    let request_builder = HttpRequest::builder()
 25        .method(Method::POST)
 26        .uri(uri)
 27        .header("Content-Type", "application/json");
 28
 29    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
 30    let mut response = client.send(request).await?;
 31    if response.status().is_success() {
 32        let reader = BufReader::new(response.into_body());
 33        Ok(reader
 34            .lines()
 35            .filter_map(|line| async move {
 36                match line {
 37                    Ok(line) => {
 38                        if let Some(line) = line.strip_prefix("data: ") {
 39                            match serde_json::from_str(line) {
 40                                Ok(response) => Some(Ok(response)),
 41                                Err(error) => Some(Err(anyhow!(error))),
 42                            }
 43                        } else {
 44                            None
 45                        }
 46                    }
 47                    Err(error) => Some(Err(anyhow!(error))),
 48                }
 49            })
 50            .boxed())
 51    } else {
 52        let mut text = String::new();
 53        response.body_mut().read_to_string(&mut text).await?;
 54        Err(anyhow!(
 55            "error during streamGenerateContent, status code: {:?}, body: {}",
 56            response.status(),
 57            text
 58        ))
 59    }
 60}
 61
 62pub async fn count_tokens(
 63    client: &dyn HttpClient,
 64    api_url: &str,
 65    api_key: &str,
 66    request: CountTokensRequest,
 67) -> Result<CountTokensResponse> {
 68    let uri = format!(
 69        "{}/v1beta/models/gemini-pro:countTokens?key={}",
 70        api_url, api_key
 71    );
 72    let request = serde_json::to_string(&request)?;
 73
 74    let request_builder = HttpRequest::builder()
 75        .method(Method::POST)
 76        .uri(&uri)
 77        .header("Content-Type", "application/json");
 78
 79    let http_request = request_builder.body(AsyncBody::from(request))?;
 80    let mut response = client.send(http_request).await?;
 81    let mut text = String::new();
 82    response.body_mut().read_to_string(&mut text).await?;
 83    if response.status().is_success() {
 84        Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
 85    } else {
 86        Err(anyhow!(
 87            "error during countTokens, status code: {:?}, body: {}",
 88            response.status(),
 89            text
 90        ))
 91    }
 92}
 93
 94#[derive(Debug, Serialize, Deserialize)]
 95pub enum Task {
 96    #[serde(rename = "generateContent")]
 97    GenerateContent,
 98    #[serde(rename = "streamGenerateContent")]
 99    StreamGenerateContent,
100    #[serde(rename = "countTokens")]
101    CountTokens,
102    #[serde(rename = "embedContent")]
103    EmbedContent,
104    #[serde(rename = "batchEmbedContents")]
105    BatchEmbedContents,
106}
107
108#[derive(Debug, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct GenerateContentRequest {
111    #[serde(default, skip_serializing_if = "String::is_empty")]
112    pub model: String,
113    pub contents: Vec<Content>,
114    pub generation_config: Option<GenerationConfig>,
115    pub safety_settings: Option<Vec<SafetySetting>>,
116}
117
118#[derive(Debug, Serialize, Deserialize)]
119#[serde(rename_all = "camelCase")]
120pub struct GenerateContentResponse {
121    pub candidates: Option<Vec<GenerateContentCandidate>>,
122    pub prompt_feedback: Option<PromptFeedback>,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
126#[serde(rename_all = "camelCase")]
127pub struct GenerateContentCandidate {
128    pub index: Option<usize>,
129    pub content: Content,
130    pub finish_reason: Option<String>,
131    pub finish_message: Option<String>,
132    pub safety_ratings: Option<Vec<SafetyRating>>,
133    pub citation_metadata: Option<CitationMetadata>,
134}
135
136#[derive(Debug, Serialize, Deserialize)]
137#[serde(rename_all = "camelCase")]
138pub struct Content {
139    pub parts: Vec<Part>,
140    pub role: Role,
141}
142
143#[derive(Debug, Deserialize, Serialize)]
144#[serde(rename_all = "camelCase")]
145pub enum Role {
146    User,
147    Model,
148}
149
150#[derive(Debug, Serialize, Deserialize)]
151#[serde(untagged)]
152pub enum Part {
153    TextPart(TextPart),
154    InlineDataPart(InlineDataPart),
155}
156
157#[derive(Debug, Serialize, Deserialize)]
158#[serde(rename_all = "camelCase")]
159pub struct TextPart {
160    pub text: String,
161}
162
163#[derive(Debug, Serialize, Deserialize)]
164#[serde(rename_all = "camelCase")]
165pub struct InlineDataPart {
166    pub inline_data: GenerativeContentBlob,
167}
168
169#[derive(Debug, Serialize, Deserialize)]
170#[serde(rename_all = "camelCase")]
171pub struct GenerativeContentBlob {
172    pub mime_type: String,
173    pub data: String,
174}
175
176#[derive(Debug, Serialize, Deserialize)]
177#[serde(rename_all = "camelCase")]
178pub struct CitationSource {
179    pub start_index: Option<usize>,
180    pub end_index: Option<usize>,
181    pub uri: Option<String>,
182    pub license: Option<String>,
183}
184
185#[derive(Debug, Serialize, Deserialize)]
186#[serde(rename_all = "camelCase")]
187pub struct CitationMetadata {
188    pub citation_sources: Vec<CitationSource>,
189}
190
191#[derive(Debug, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct PromptFeedback {
194    pub block_reason: Option<String>,
195    pub safety_ratings: Vec<SafetyRating>,
196    pub block_reason_message: Option<String>,
197}
198
199#[derive(Debug, Deserialize, Serialize)]
200#[serde(rename_all = "camelCase")]
201pub struct GenerationConfig {
202    pub candidate_count: Option<usize>,
203    pub stop_sequences: Option<Vec<String>>,
204    pub max_output_tokens: Option<usize>,
205    pub temperature: Option<f64>,
206    pub top_p: Option<f64>,
207    pub top_k: Option<usize>,
208}
209
210#[derive(Debug, Serialize, Deserialize)]
211#[serde(rename_all = "camelCase")]
212pub struct SafetySetting {
213    pub category: HarmCategory,
214    pub threshold: HarmBlockThreshold,
215}
216
217#[derive(Debug, Serialize, Deserialize)]
218pub enum HarmCategory {
219    #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
220    Unspecified,
221    #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
222    Derogatory,
223    #[serde(rename = "HARM_CATEGORY_TOXICITY")]
224    Toxicity,
225    #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
226    Violence,
227    #[serde(rename = "HARM_CATEGORY_SEXUAL")]
228    Sexual,
229    #[serde(rename = "HARM_CATEGORY_MEDICAL")]
230    Medical,
231    #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
232    Dangerous,
233    #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
234    Harassment,
235    #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
236    HateSpeech,
237    #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
238    SexuallyExplicit,
239    #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
240    DangerousContent,
241}
242
243#[derive(Debug, Serialize, Deserialize)]
244pub enum HarmBlockThreshold {
245    #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
246    Unspecified,
247    #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
248    BlockLowAndAbove,
249    #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
250    BlockMediumAndAbove,
251    #[serde(rename = "BLOCK_ONLY_HIGH")]
252    BlockOnlyHigh,
253    #[serde(rename = "BLOCK_NONE")]
254    BlockNone,
255}
256
257#[derive(Debug, Serialize, Deserialize)]
258#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
259pub enum HarmProbability {
260    #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
261    Unspecified,
262    Negligible,
263    Low,
264    Medium,
265    High,
266}
267
268#[derive(Debug, Serialize, Deserialize)]
269#[serde(rename_all = "camelCase")]
270pub struct SafetyRating {
271    pub category: HarmCategory,
272    pub probability: HarmProbability,
273}
274
275#[derive(Debug, Serialize, Deserialize)]
276#[serde(rename_all = "camelCase")]
277pub struct CountTokensRequest {
278    pub contents: Vec<Content>,
279}
280
281#[derive(Debug, Serialize, Deserialize)]
282#[serde(rename_all = "camelCase")]
283pub struct CountTokensResponse {
284    pub total_tokens: usize,
285}
286
287#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
288#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
289pub enum Model {
290    #[serde(rename = "gemini-1.5-pro")]
291    Gemini15Pro,
292    #[serde(rename = "gemini-1.5-flash")]
293    Gemini15Flash,
294    #[serde(rename = "custom")]
295    Custom {
296        name: String,
297        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
298        display_name: Option<String>,
299        max_tokens: usize,
300    },
301}
302
303impl Model {
304    pub fn id(&self) -> &str {
305        match self {
306            Model::Gemini15Pro => "gemini-1.5-pro",
307            Model::Gemini15Flash => "gemini-1.5-flash",
308            Model::Custom { name, .. } => name,
309        }
310    }
311
312    pub fn display_name(&self) -> &str {
313        match self {
314            Model::Gemini15Pro => "Gemini 1.5 Pro",
315            Model::Gemini15Flash => "Gemini 1.5 Flash",
316            Self::Custom {
317                name, display_name, ..
318            } => display_name.as_ref().unwrap_or(name),
319        }
320    }
321
322    pub fn max_token_count(&self) -> usize {
323        match self {
324            Model::Gemini15Pro => 2_000_000,
325            Model::Gemini15Flash => 1_000_000,
326            Model::Custom { max_tokens, .. } => *max_tokens,
327        }
328    }
329}
330
331impl std::fmt::Display for Model {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        write!(f, "{}", self.id())
334    }
335}
336
337pub fn extract_text_from_events(
338    events: impl Stream<Item = Result<GenerateContentResponse>>,
339) -> impl Stream<Item = Result<String>> {
340    events.filter_map(|event| async move {
341        match event {
342            Ok(event) => event.candidates.and_then(|candidates| {
343                candidates.into_iter().next().and_then(|candidate| {
344                    candidate.content.parts.into_iter().next().and_then(|part| {
345                        if let Part::TextPart(TextPart { text }) = part {
346                            Some(Ok(text))
347                        } else {
348                            None
349                        }
350                    })
351                })
352            }),
353            Err(error) => Some(Err(error)),
354        }
355    })
356}