@@ -1,6 +1,6 @@
mod supported_countries;
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
@@ -15,6 +15,20 @@ pub async fn stream_generate_content(
api_key: &str,
mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
+ if request.contents.is_empty() {
+ bail!("Request must contain at least one content item");
+ }
+
+ if let Some(user_content) = request
+ .contents
+ .iter()
+ .find(|content| content.role == Role::User)
+ {
+ if user_content.parts.is_empty() {
+ bail!("User content must contain at least one part");
+ }
+ }
+
let uri = format!(
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
model = request.model
@@ -140,7 +154,7 @@ pub struct Content {
pub role: Role,
}
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
User,
@@ -291,6 +305,8 @@ pub enum Model {
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
+ #[serde(rename = "gemini-2.0-flash-exp")]
+ Gemini20Flash,
#[serde(rename = "custom")]
Custom {
name: String,
@@ -305,6 +321,7 @@ impl Model {
match self {
Model::Gemini15Pro => "gemini-1.5-pro",
Model::Gemini15Flash => "gemini-1.5-flash",
+ Model::Gemini20Flash => "gemini-2.0-flash-exp",
Model::Custom { name, .. } => name,
}
}
@@ -313,6 +330,7 @@ impl Model {
match self {
Model::Gemini15Pro => "Gemini 1.5 Pro",
Model::Gemini15Flash => "Gemini 1.5 Flash",
+ Model::Gemini20Flash => "Gemini 2.0 Flash",
Self::Custom {
name, display_name, ..
} => display_name.as_ref().unwrap_or(name),
@@ -323,6 +341,7 @@ impl Model {
match self {
Model::Gemini15Pro => 2_000_000,
Model::Gemini15Flash => 1_000_000,
+ Model::Gemini20Flash => 1_000_000,
Model::Custom { max_tokens, .. } => *max_tokens,
}
}