open_ai.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use serde::{Deserialize, Serialize};
  6use serde_json::{Map, Value};
  7use std::{convert::TryFrom, future::Future, time::Duration};
  8use strum::EnumIter;
  9
 10pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 11
 12#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 13#[serde(rename_all = "lowercase")]
 14pub enum Role {
 15    User,
 16    Assistant,
 17    System,
 18    Tool,
 19}
 20
 21impl TryFrom<String> for Role {
 22    type Error = anyhow::Error;
 23
 24    fn try_from(value: String) -> Result<Self> {
 25        match value.as_str() {
 26            "user" => Ok(Self::User),
 27            "assistant" => Ok(Self::Assistant),
 28            "system" => Ok(Self::System),
 29            "tool" => Ok(Self::Tool),
 30            _ => Err(anyhow!("invalid role '{value}'")),
 31        }
 32    }
 33}
 34
 35impl From<Role> for String {
 36    fn from(val: Role) -> Self {
 37        match val {
 38            Role::User => "user".to_owned(),
 39            Role::Assistant => "assistant".to_owned(),
 40            Role::System => "system".to_owned(),
 41            Role::Tool => "tool".to_owned(),
 42        }
 43    }
 44}
 45
 46#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 47#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 48pub enum Model {
 49    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
 50    ThreePointFiveTurbo,
 51    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
 52    Four,
 53    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
 54    FourTurbo,
 55    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
 56    #[default]
 57    FourOmni,
 58}
 59
 60impl Model {
 61    pub fn from_id(id: &str) -> Result<Self> {
 62        match id {
 63            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 64            "gpt-4" => Ok(Self::Four),
 65            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 66            "gpt-4o" => Ok(Self::FourOmni),
 67            _ => Err(anyhow!("invalid model id")),
 68        }
 69    }
 70
 71    pub fn id(&self) -> &'static str {
 72        match self {
 73            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 74            Self::Four => "gpt-4",
 75            Self::FourTurbo => "gpt-4-turbo-preview",
 76            Self::FourOmni => "gpt-4o",
 77        }
 78    }
 79
 80    pub fn display_name(&self) -> &'static str {
 81        match self {
 82            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 83            Self::Four => "gpt-4",
 84            Self::FourTurbo => "gpt-4-turbo",
 85            Self::FourOmni => "gpt-4o",
 86        }
 87    }
 88
 89    pub fn max_token_count(&self) -> usize {
 90        match self {
 91            Model::ThreePointFiveTurbo => 4096,
 92            Model::Four => 8192,
 93            Model::FourTurbo => 128000,
 94            Model::FourOmni => 128000,
 95        }
 96    }
 97}
 98
 99#[derive(Debug, Serialize)]
100pub struct Request {
101    pub model: Model,
102    pub messages: Vec<RequestMessage>,
103    pub stream: bool,
104    pub stop: Vec<String>,
105    pub temperature: f32,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub tool_choice: Option<String>,
108    #[serde(skip_serializing_if = "Vec::is_empty")]
109    pub tools: Vec<ToolDefinition>,
110}
111
112#[derive(Debug, Serialize)]
113pub struct FunctionDefinition {
114    pub name: String,
115    pub description: Option<String>,
116    pub parameters: Option<Map<String, Value>>,
117}
118
119#[derive(Serialize, Debug)]
120#[serde(tag = "type", rename_all = "snake_case")]
121pub enum ToolDefinition {
122    #[allow(dead_code)]
123    Function { function: FunctionDefinition },
124}
125
126#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
127#[serde(tag = "role", rename_all = "lowercase")]
128pub enum RequestMessage {
129    Assistant {
130        content: Option<String>,
131        #[serde(default, skip_serializing_if = "Vec::is_empty")]
132        tool_calls: Vec<ToolCall>,
133    },
134    User {
135        content: String,
136    },
137    System {
138        content: String,
139    },
140    Tool {
141        content: String,
142        tool_call_id: String,
143    },
144}
145
146#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
147pub struct ToolCall {
148    pub id: String,
149    #[serde(flatten)]
150    pub content: ToolCallContent,
151}
152
153#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
154#[serde(tag = "type", rename_all = "lowercase")]
155pub enum ToolCallContent {
156    Function { function: FunctionContent },
157}
158
159#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
160pub struct FunctionContent {
161    pub name: String,
162    pub arguments: String,
163}
164
165#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
166pub struct ResponseMessageDelta {
167    pub role: Option<Role>,
168    pub content: Option<String>,
169    #[serde(default, skip_serializing_if = "Vec::is_empty")]
170    pub tool_calls: Vec<ToolCallChunk>,
171}
172
173#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
174pub struct ToolCallChunk {
175    pub index: usize,
176    pub id: Option<String>,
177
178    // There is also an optional `type` field that would determine if a
179    // function is there. Sometimes this streams in with the `function` before
180    // it streams in the `type`
181    pub function: Option<FunctionChunk>,
182}
183
184#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
185pub struct FunctionChunk {
186    pub name: Option<String>,
187    pub arguments: Option<String>,
188}
189
190#[derive(Deserialize, Debug)]
191pub struct Usage {
192    pub prompt_tokens: u32,
193    pub completion_tokens: u32,
194    pub total_tokens: u32,
195}
196
197#[derive(Deserialize, Debug)]
198pub struct ChoiceDelta {
199    pub index: u32,
200    pub delta: ResponseMessageDelta,
201    pub finish_reason: Option<String>,
202}
203
204#[derive(Deserialize, Debug)]
205pub struct ResponseStreamEvent {
206    pub created: u32,
207    pub model: String,
208    pub choices: Vec<ChoiceDelta>,
209    pub usage: Option<Usage>,
210}
211
212pub async fn stream_completion(
213    client: &dyn HttpClient,
214    api_url: &str,
215    api_key: &str,
216    request: Request,
217    low_speed_timeout: Option<Duration>,
218) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
219    let uri = format!("{api_url}/chat/completions");
220    let mut request_builder = HttpRequest::builder()
221        .method(Method::POST)
222        .uri(uri)
223        .header("Content-Type", "application/json")
224        .header("Authorization", format!("Bearer {}", api_key));
225
226    if let Some(low_speed_timeout) = low_speed_timeout {
227        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
228    };
229
230    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
231    let mut response = client.send(request).await?;
232    if response.status().is_success() {
233        let reader = BufReader::new(response.into_body());
234        Ok(reader
235            .lines()
236            .filter_map(|line| async move {
237                match line {
238                    Ok(line) => {
239                        let line = line.strip_prefix("data: ")?;
240                        if line == "[DONE]" {
241                            None
242                        } else {
243                            match serde_json::from_str(line) {
244                                Ok(response) => Some(Ok(response)),
245                                Err(error) => Some(Err(anyhow!(error))),
246                            }
247                        }
248                    }
249                    Err(error) => Some(Err(anyhow!(error))),
250                }
251            })
252            .boxed())
253    } else {
254        let mut body = String::new();
255        response.body_mut().read_to_string(&mut body).await?;
256
257        #[derive(Deserialize)]
258        struct OpenAiResponse {
259            error: OpenAiError,
260        }
261
262        #[derive(Deserialize)]
263        struct OpenAiError {
264            message: String,
265        }
266
267        match serde_json::from_str::<OpenAiResponse>(&body) {
268            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
269                "Failed to connect to OpenAI API: {}",
270                response.error.message,
271            )),
272
273            _ => Err(anyhow!(
274                "Failed to connect to OpenAI API: {} {}",
275                response.status(),
276                body,
277            )),
278        }
279    }
280}
281
282#[derive(Copy, Clone, Serialize, Deserialize)]
283pub enum OpenAiEmbeddingModel {
284    #[serde(rename = "text-embedding-3-small")]
285    TextEmbedding3Small,
286    #[serde(rename = "text-embedding-3-large")]
287    TextEmbedding3Large,
288}
289
290#[derive(Serialize)]
291struct OpenAiEmbeddingRequest<'a> {
292    model: OpenAiEmbeddingModel,
293    input: Vec<&'a str>,
294}
295
296#[derive(Deserialize)]
297pub struct OpenAiEmbeddingResponse {
298    pub data: Vec<OpenAiEmbedding>,
299}
300
301#[derive(Deserialize)]
302pub struct OpenAiEmbedding {
303    pub embedding: Vec<f32>,
304}
305
306pub fn embed<'a>(
307    client: &dyn HttpClient,
308    api_url: &str,
309    api_key: &str,
310    model: OpenAiEmbeddingModel,
311    texts: impl IntoIterator<Item = &'a str>,
312) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
313    let uri = format!("{api_url}/embeddings");
314
315    let request = OpenAiEmbeddingRequest {
316        model,
317        input: texts.into_iter().collect(),
318    };
319    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
320    let request = HttpRequest::builder()
321        .method(Method::POST)
322        .uri(uri)
323        .header("Content-Type", "application/json")
324        .header("Authorization", format!("Bearer {}", api_key))
325        .body(body)
326        .map(|request| client.send(request));
327
328    async move {
329        let mut response = request?.await?;
330        let mut body = String::new();
331        response.body_mut().read_to_string(&mut body).await?;
332
333        if response.status().is_success() {
334            let response: OpenAiEmbeddingResponse =
335                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
336            Ok(response)
337        } else {
338            Err(anyhow!(
339                "error during embedding, status: {:?}, body: {:?}",
340                response.status(),
341                body
342            ))
343        }
344    }
345}