open_ai.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use serde::{Deserialize, Serialize};
  6use serde_json::{Map, Value};
  7use std::{convert::TryFrom, future::Future, time::Duration};
  8use strum::EnumIter;
  9
/// Base URL of the hosted OpenAI REST API; endpoint paths such as
/// `/chat/completions` and `/embeddings` are appended to this value.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 11
/// Author of a chat message, mirroring the `role` field of the OpenAI
/// chat completions API (serialized/deserialized in lowercase).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 20
 21impl TryFrom<String> for Role {
 22    type Error = anyhow::Error;
 23
 24    fn try_from(value: String) -> Result<Self> {
 25        match value.as_str() {
 26            "user" => Ok(Self::User),
 27            "assistant" => Ok(Self::Assistant),
 28            "system" => Ok(Self::System),
 29            "tool" => Ok(Self::Tool),
 30            _ => Err(anyhow!("invalid role '{value}'")),
 31        }
 32    }
 33}
 34
 35impl From<Role> for String {
 36    fn from(val: Role) -> Self {
 37        match val {
 38            Role::User => "user".to_owned(),
 39            Role::Assistant => "assistant".to_owned(),
 40            Role::System => "system".to_owned(),
 41            Role::Tool => "tool".to_owned(),
 42        }
 43    }
 44}
 45
/// The chat model to request. Known models (de)serialize to their API
/// ids, with serde `alias`es accepting older dated snapshot names;
/// `Custom` carries an arbitrary model name plus its token limit.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    /// A user-specified model; `max_tokens` is its context size,
    /// returned verbatim by `max_token_count`.
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}
 61
 62impl Model {
 63    pub fn from_id(id: &str) -> Result<Self> {
 64        match id {
 65            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 66            "gpt-4" => Ok(Self::Four),
 67            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 68            "gpt-4o" => Ok(Self::FourOmni),
 69            _ => Err(anyhow!("invalid model id")),
 70        }
 71    }
 72
 73    pub fn id(&self) -> &'static str {
 74        match self {
 75            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 76            Self::Four => "gpt-4",
 77            Self::FourTurbo => "gpt-4-turbo-preview",
 78            Self::FourOmni => "gpt-4o",
 79            Self::Custom { .. } => "custom",
 80        }
 81    }
 82
 83    pub fn display_name(&self) -> &str {
 84        match self {
 85            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 86            Self::Four => "gpt-4",
 87            Self::FourTurbo => "gpt-4-turbo",
 88            Self::FourOmni => "gpt-4o",
 89            Self::Custom { name, .. } => name,
 90        }
 91    }
 92
 93    pub fn max_token_count(&self) -> usize {
 94        match self {
 95            Model::ThreePointFiveTurbo => 4096,
 96            Model::Four => 8192,
 97            Model::FourTurbo => 128000,
 98            Model::FourOmni => 128000,
 99            Model::Custom { max_tokens, .. } => *max_tokens,
100        }
101    }
102}
103
104fn serialize_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
105where
106    S: serde::Serializer,
107{
108    match model {
109        Model::Custom { name, .. } => serializer.serialize_str(name),
110        _ => serializer.serialize_str(model.id()),
111    }
112}
113
/// Body of a `/chat/completions` request (see `stream_completion`).
#[derive(Debug, Serialize)]
pub struct Request {
    /// Serialized as a bare model-name string via `serialize_model`.
    #[serde(serialize_with = "serialize_model")]
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    /// Whether to ask the API to stream the response;
    /// `stream_completion` always parses the body as an SSE stream.
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// Omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
127
/// Description of a callable function, embedded in a `ToolDefinition`.
#[derive(Debug, Serialize)]
pub struct FunctionDefinition {
    pub name: String,
    /// Optional natural-language description of the function.
    pub description: Option<String>,
    /// Free-form JSON object describing the parameters (the OpenAI API
    /// expects a JSON schema here).
    pub parameters: Option<Map<String, Value>>,
}
134
/// A tool the model may invoke. Serialized with a `"type"` tag, so the
/// `Function` variant becomes `{"type": "function", "function": {…}}`.
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
141
/// A single chat message, (de)serialized with an internal `"role"` tag
/// in lowercase (`assistant`, `user`, `system`, `tool`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A model-authored message: text content, tool calls, or both.
    Assistant {
        content: Option<String>,
        /// Defaults to empty on deserialize; omitted from JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// Result of a tool invocation, correlated with the originating
    /// call by `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
161
/// A complete tool invocation emitted by the model. `content` is
/// `#[serde(flatten)]`ed so its fields serialize alongside `id`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
168
/// Payload of a tool call, discriminated by a lowercase `"type"` tag
/// (currently only `function`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
174
/// A function invocation: the function's name plus its arguments as a
/// raw string (not parsed into JSON here).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
180
/// Incremental fragment of an assistant message from a streaming
/// response; any field may be absent on a given chunk, so callers must
/// accumulate deltas themselves.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Defaults to empty on deserialize; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCallChunk>,
}
188
/// Partial tool-call data inside a streamed delta. `index` identifies
/// which in-progress tool call this chunk belongs to, so chunks with
/// the same index can be merged by the caller.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
199
/// Partial function-call data within a `ToolCallChunk`; `name` and
/// `arguments` stream in piecemeal and either may be missing.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
205
/// Token accounting reported by the API for a completion.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
212
/// One choice's incremental update within a stream event.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Set on the final chunk for this choice (e.g. why generation
    /// stopped); `None` while streaming continues.
    pub finish_reason: Option<String>,
}
219
/// One parsed server-sent event from a streaming chat completion
/// (see `stream_completion`).
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    // Per the OpenAI API reference this is a Unix timestamp (seconds).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Present only when the API includes usage data in the stream.
    pub usage: Option<Usage>,
}
227
/// Starts a streaming chat completion against an OpenAI-compatible
/// endpoint.
///
/// Posts `request` as JSON to `{api_url}/chat/completions` with a
/// bearer-token `Authorization` header and returns a stream of
/// `ResponseStreamEvent`s parsed from the server-sent-event lines of
/// the response body.
///
/// `low_speed_timeout`, when set, aborts the transfer if throughput
/// stays below 100 bytes/sec for the given duration (isahc option).
///
/// # Errors
///
/// Fails if the request cannot be built or sent. On a non-success HTTP
/// status, returns an error carrying the API's error `message` when the
/// body parses as an OpenAI error payload, otherwise the raw status and
/// body. Individual stream items are `Err` on read or JSON-parse
/// failures.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines start with "data: "; anything else
                        // (e.g. blank separator lines) is silently skipped.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: dropped here; the
                            // stream itself ends when the body is exhausted.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(response) => Some(Ok(response)),
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Local mirror of the API's error envelope: {"error": {"message": …}}.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        // Prefer the structured API error message; fall back to echoing
        // the status and raw body when the payload isn't that shape.
        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
297
/// Embedding model selector, serialized as the model's API id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
305
/// Body of an `/embeddings` request; `input` borrows the caller's texts
/// for the duration of serialization.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
311
/// Parsed `/embeddings` response: one entry in `data` per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
316
/// A single embedding vector from the `/embeddings` response.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
321
322pub fn embed<'a>(
323    client: &dyn HttpClient,
324    api_url: &str,
325    api_key: &str,
326    model: OpenAiEmbeddingModel,
327    texts: impl IntoIterator<Item = &'a str>,
328) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
329    let uri = format!("{api_url}/embeddings");
330
331    let request = OpenAiEmbeddingRequest {
332        model,
333        input: texts.into_iter().collect(),
334    };
335    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
336    let request = HttpRequest::builder()
337        .method(Method::POST)
338        .uri(uri)
339        .header("Content-Type", "application/json")
340        .header("Authorization", format!("Bearer {}", api_key))
341        .body(body)
342        .map(|request| client.send(request));
343
344    async move {
345        let mut response = request?.await?;
346        let mut body = String::new();
347        response.body_mut().read_to_string(&mut body).await?;
348
349        if response.status().is_success() {
350            let response: OpenAiEmbeddingResponse =
351                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
352            Ok(response)
353        } else {
354            Err(anyhow!(
355                "error during embedding, status: {:?}, body: {:?}",
356                response.status(),
357                body
358            ))
359        }
360    }
361}