open_ai.rs

mod supported_countries;

use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

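/// Serialization helper: returns `true` when an optional collection is absent or
/// empty, so fields like `tool_calls` can be omitted via `skip_serializing_if`.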
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

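/// Chat models known to this crate. Dated snapshot ids (e.g. "gpt-4-0613")
/// deserialize to the same variant via serde aliases, while `Custom` carries an
/// arbitrary model name together with its context window size.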
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
    FourOmniMini,
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}

impl Model {
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            "gpt-4o" => Ok(Self::FourOmni),
            "gpt-4o-mini" => Ok(Self::FourOmniMini),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo-preview",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::ThreePointFiveTurbo => 4096,
            Self::Four => 8192,
            Self::FourTurbo => 128000,
            Self::FourOmni => 128000,
            Self::FourOmniMini => 128000,
            Self::Custom { max_tokens, .. } => *max_tokens,
        }
    }
}
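
// A minimal, illustrative sketch of the id round trip above; this test module is
// not part of the original crate and assumes the standard `cargo test` harness.
#[cfg(test)]
mod model_id_tests {
    use super::Model;

    #[test]
    fn from_id_round_trips_canonical_ids() {
        // `from_id` accepts only the canonical ids returned by `id()`; dated
        // snapshots are handled by serde aliases during deserialization instead.
        for model in [
            Model::ThreePointFiveTurbo,
            Model::Four,
            Model::FourTurbo,
            Model::FourOmni,
            Model::FourOmniMini,
        ] {
            assert_eq!(Model::from_id(model.id()).unwrap(), model);
        }
        // Unknown ids are rejected rather than mapped to `Model::Custom`.
        assert!(Model::from_id("not-a-model").is_err());
    }
}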

#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
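
// For reference, a request built from the struct above serializes to roughly the
// following JSON (illustrative values; optional fields are omitted when empty):
//
//     {"model":"gpt-4o","messages":[{"role":"user","content":"Hello"}],
//      "stream":true,"stop":[],"temperature":1.0}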

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

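/// A chat message in the request body. The enum is internally tagged on `role`,
/// so `RequestMessage::User { content }` serializes as
/// `{"role": "user", "content": "..."}`; assistant messages may additionally carry
/// `tool_calls`, and tool results reference their call via `tool_call_id`.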
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that indicates whether a function is
    // present; the `function` payload sometimes arrives in a chunk before the
    // `type` does.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

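// A minimal usage sketch for `stream_completion` (illustrative; `client` is any
// `HttpClient` implementation and the surrounding async context is assumed):
//
//     let mut events =
//         stream_completion(client, OPEN_AI_API_URL, &api_key, request, None).await?;
//     while let Some(event) = events.next().await {
//         for choice in event?.choices {
//             if let Some(text) = choice.delta.content {
//                 print!("{text}");
//             }
//         }
//     }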
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    }

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // The endpoint replies with server-sent events: each payload line is
        // prefixed with "data: " and the stream ends with a "[DONE]" sentinel.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Surface the structured OpenAI error message when the body parses; fall
        // back to the raw status and body otherwise.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

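// A hedged sketch of requesting embeddings with `embed` (illustrative; `client`
// and `api_key` come from the calling application):
//
//     let response = embed(
//         client,
//         OPEN_AI_API_URL,
//         &api_key,
//         OpenAiEmbeddingModel::TextEmbedding3Small,
//         ["hello world"],
//     )
//     .await?;
//     let vector: &Vec<f32> = &response.data[0].embedding;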
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    // The HTTP request is built outside the async block; any builder error is
    // propagated when the returned future is awaited.
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            // Each streamed event typically carries a single choice; yield its
            // delta text when present and drop events without content.
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}
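
// An illustrative sketch (not part of the original crate) of how
// `extract_text_from_events` flattens streamed events into text chunks.
#[cfg(test)]
mod extract_text_tests {
    use super::*;
    use futures::{stream, FutureExt as _};

    #[test]
    fn pulls_delta_content_out_of_events() {
        let event = ResponseStreamEvent {
            created: 0,
            model: "gpt-4o".to_string(),
            choices: vec![ChoiceDelta {
                index: 0,
                delta: ResponseMessageDelta {
                    role: Some(Role::Assistant),
                    content: Some("hello".to_string()),
                    tool_calls: None,
                },
                finish_reason: None,
            }],
            usage: None,
        };

        // `stream::iter` yields immediately, so the collected future resolves
        // without an executor via `now_or_never`.
        let chunks: Vec<Result<String>> =
            extract_text_from_events(stream::iter([Ok(event)]))
                .collect::<Vec<_>>()
                .now_or_never()
                .expect("stream::iter is always ready");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].as_ref().unwrap(), "hello");
    }
}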