@@ -1,12 +1,21 @@
mod supported_countries;
use anyhow::{anyhow, Context, Result};
-use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
+use futures::{
+ io::BufReader,
+ stream::{self, BoxStream},
+ AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
-use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
+use std::{
+ convert::TryFrom,
+ future::{self, Future},
+ pin::Pin,
+ time::Duration,
+};
use strum::EnumIter;
pub use supported_countries::*;
@@ -72,6 +81,7 @@ pub enum Model {
display_name: Option<String>,
max_tokens: usize,
max_output_tokens: Option<u32>,
+ max_completion_tokens: Option<u32>,
},
}
@@ -139,6 +149,7 @@ pub struct Request {
pub stream: bool,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
@@ -263,6 +274,111 @@ pub struct ResponseStreamEvent {
pub usage: Option<Usage>,
}
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Response {
+ pub id: String,
+ pub object: String,
+ pub created: u64,
+ pub model: String,
+ pub choices: Vec<Choice>,
+ pub usage: Usage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Choice {
+ pub index: u32,
+ pub message: RequestMessage,
+ pub finish_reason: Option<String>,
+}
+
+pub async fn complete(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+ low_speed_timeout: Option<Duration>,
+) -> Result<Response> {
+ let uri = format!("{api_url}/chat/completions");
+ let mut request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key));
+ if let Some(low_speed_timeout) = low_speed_timeout {
+ request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
+ };
+
+ let mut request_body = request;
+ request_body.stream = false;
+
+ let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
+ let mut response = client.send(request).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: Response = serde_json::from_str(&body)?;
+ Ok(response)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAiResponse {
+ error: OpenAiError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAiError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAiResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
+ ResponseStreamEvent {
+ created: response.created as u32,
+ model: response.model,
+ choices: response
+ .choices
+ .into_iter()
+ .map(|choice| ChoiceDelta {
+ index: choice.index,
+ delta: ResponseMessageDelta {
+ role: Some(match choice.message {
+ RequestMessage::Assistant { .. } => Role::Assistant,
+ RequestMessage::User { .. } => Role::User,
+ RequestMessage::System { .. } => Role::System,
+ RequestMessage::Tool { .. } => Role::Tool,
+ }),
+ content: match choice.message {
+ RequestMessage::Assistant { content, .. } => content,
+ RequestMessage::User { content } => Some(content),
+ RequestMessage::System { content } => Some(content),
+ RequestMessage::Tool { content, .. } => Some(content),
+ },
+ tool_calls: None,
+ },
+ finish_reason: choice.finish_reason,
+ })
+ .collect(),
+ usage: Some(response.usage),
+ }
+}
+
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
@@ -270,6 +386,12 @@ pub async fn stream_completion(
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+ if request.model == "o1-preview" || request.model == "o1-mini" {
+ let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
+ let response_stream_event = response.map(adapt_response_to_stream);
+ return Ok(stream::once(future::ready(response_stream_event)).boxed());
+ }
+
let uri = format!("{api_url}/chat/completions");
let mut request_builder = HttpRequest::builder()
.method(Method::POST)