open_ai.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use serde::{Deserialize, Serialize};
  4use serde_json::{Map, Value};
  5use std::{convert::TryFrom, future::Future};
  6use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  7
  8pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
  9
 10#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 11#[serde(rename_all = "lowercase")]
 12pub enum Role {
 13    User,
 14    Assistant,
 15    System,
 16    Tool,
 17}
 18
 19impl TryFrom<String> for Role {
 20    type Error = anyhow::Error;
 21
 22    fn try_from(value: String) -> Result<Self> {
 23        match value.as_str() {
 24            "user" => Ok(Self::User),
 25            "assistant" => Ok(Self::Assistant),
 26            "system" => Ok(Self::System),
 27            "tool" => Ok(Self::Tool),
 28            _ => Err(anyhow!("invalid role '{value}'")),
 29        }
 30    }
 31}
 32
 33impl From<Role> for String {
 34    fn from(val: Role) -> Self {
 35        match val {
 36            Role::User => "user".to_owned(),
 37            Role::Assistant => "assistant".to_owned(),
 38            Role::System => "system".to_owned(),
 39            Role::Tool => "tool".to_owned(),
 40        }
 41    }
 42}
 43
 44#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 45#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 46pub enum Model {
 47    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
 48    ThreePointFiveTurbo,
 49    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
 50    Four,
 51    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
 52    #[default]
 53    FourTurbo,
 54}
 55
 56impl Model {
 57    pub fn from_id(id: &str) -> Result<Self> {
 58        match id {
 59            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 60            "gpt-4" => Ok(Self::Four),
 61            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 62            _ => Err(anyhow!("invalid model id")),
 63        }
 64    }
 65
 66    pub fn id(&self) -> &'static str {
 67        match self {
 68            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 69            Self::Four => "gpt-4",
 70            Self::FourTurbo => "gpt-4-turbo-preview",
 71        }
 72    }
 73
 74    pub fn display_name(&self) -> &'static str {
 75        match self {
 76            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 77            Self::Four => "gpt-4",
 78            Self::FourTurbo => "gpt-4-turbo",
 79        }
 80    }
 81
 82    pub fn max_token_count(&self) -> usize {
 83        match self {
 84            Model::ThreePointFiveTurbo => 4096,
 85            Model::Four => 8192,
 86            Model::FourTurbo => 128000,
 87        }
 88    }
 89}
 90
 91#[derive(Debug, Serialize)]
 92pub struct Request {
 93    pub model: Model,
 94    pub messages: Vec<RequestMessage>,
 95    pub stream: bool,
 96    pub stop: Vec<String>,
 97    pub temperature: f32,
 98    #[serde(skip_serializing_if = "Option::is_none")]
 99    pub tool_choice: Option<String>,
100    #[serde(skip_serializing_if = "Vec::is_empty")]
101    pub tools: Vec<ToolDefinition>,
102}
103
104#[derive(Debug, Serialize)]
105pub struct FunctionDefinition {
106    pub name: String,
107    pub description: Option<String>,
108    pub parameters: Option<Map<String, Value>>,
109}
110
111#[derive(Serialize, Debug)]
112#[serde(tag = "type", rename_all = "snake_case")]
113pub enum ToolDefinition {
114    #[allow(dead_code)]
115    Function { function: FunctionDefinition },
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119#[serde(tag = "role", rename_all = "lowercase")]
120pub enum RequestMessage {
121    Assistant {
122        content: Option<String>,
123        #[serde(default, skip_serializing_if = "Vec::is_empty")]
124        tool_calls: Vec<ToolCall>,
125    },
126    User {
127        content: String,
128    },
129    System {
130        content: String,
131    },
132    Tool {
133        content: String,
134        tool_call_id: String,
135    },
136}
137
138#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
139pub struct ToolCall {
140    pub id: String,
141    #[serde(flatten)]
142    pub content: ToolCallContent,
143}
144
145#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
146#[serde(tag = "type", rename_all = "lowercase")]
147pub enum ToolCallContent {
148    Function { function: FunctionContent },
149}
150
151#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
152pub struct FunctionContent {
153    pub name: String,
154    pub arguments: String,
155}
156
157#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
158pub struct ResponseMessageDelta {
159    pub role: Option<Role>,
160    pub content: Option<String>,
161    #[serde(default, skip_serializing_if = "Vec::is_empty")]
162    pub tool_calls: Vec<ToolCallChunk>,
163}
164
165#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
166pub struct ToolCallChunk {
167    pub index: usize,
168    pub id: Option<String>,
169
170    // There is also an optional `type` field that would determine if a
171    // function is there. Sometimes this streams in with the `function` before
172    // it streams in the `type`
173    pub function: Option<FunctionChunk>,
174}
175
176#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
177pub struct FunctionChunk {
178    pub name: Option<String>,
179    pub arguments: Option<String>,
180}
181
182#[derive(Deserialize, Debug)]
183pub struct Usage {
184    pub prompt_tokens: u32,
185    pub completion_tokens: u32,
186    pub total_tokens: u32,
187}
188
189#[derive(Deserialize, Debug)]
190pub struct ChoiceDelta {
191    pub index: u32,
192    pub delta: ResponseMessageDelta,
193    pub finish_reason: Option<String>,
194}
195
196#[derive(Deserialize, Debug)]
197pub struct ResponseStreamEvent {
198    pub created: u32,
199    pub model: String,
200    pub choices: Vec<ChoiceDelta>,
201    pub usage: Option<Usage>,
202}
203
204pub async fn stream_completion(
205    client: &dyn HttpClient,
206    api_url: &str,
207    api_key: &str,
208    request: Request,
209) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
210    let uri = format!("{api_url}/chat/completions");
211    let request = HttpRequest::builder()
212        .method(Method::POST)
213        .uri(uri)
214        .header("Content-Type", "application/json")
215        .header("Authorization", format!("Bearer {}", api_key))
216        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
217    let mut response = client.send(request).await?;
218    if response.status().is_success() {
219        let reader = BufReader::new(response.into_body());
220        Ok(reader
221            .lines()
222            .filter_map(|line| async move {
223                match line {
224                    Ok(line) => {
225                        let line = line.strip_prefix("data: ")?;
226                        if line == "[DONE]" {
227                            None
228                        } else {
229                            match serde_json::from_str(line) {
230                                Ok(response) => Some(Ok(response)),
231                                Err(error) => Some(Err(anyhow!(error))),
232                            }
233                        }
234                    }
235                    Err(error) => Some(Err(anyhow!(error))),
236                }
237            })
238            .boxed())
239    } else {
240        let mut body = String::new();
241        response.body_mut().read_to_string(&mut body).await?;
242
243        #[derive(Deserialize)]
244        struct OpenAiResponse {
245            error: OpenAiError,
246        }
247
248        #[derive(Deserialize)]
249        struct OpenAiError {
250            message: String,
251        }
252
253        match serde_json::from_str::<OpenAiResponse>(&body) {
254            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
255                "Failed to connect to OpenAI API: {}",
256                response.error.message,
257            )),
258
259            _ => Err(anyhow!(
260                "Failed to connect to OpenAI API: {} {}",
261                response.status(),
262                body,
263            )),
264        }
265    }
266}
267
268#[derive(Copy, Clone, Serialize, Deserialize)]
269pub enum OpenAiEmbeddingModel {
270    #[serde(rename = "text-embedding-3-small")]
271    TextEmbedding3Small,
272    #[serde(rename = "text-embedding-3-large")]
273    TextEmbedding3Large,
274}
275
276#[derive(Serialize)]
277struct OpenAiEmbeddingRequest<'a> {
278    model: OpenAiEmbeddingModel,
279    input: Vec<&'a str>,
280}
281
282#[derive(Deserialize)]
283pub struct OpenAiEmbeddingResponse {
284    pub data: Vec<OpenAiEmbedding>,
285}
286
287#[derive(Deserialize)]
288pub struct OpenAiEmbedding {
289    pub embedding: Vec<f32>,
290}
291
292pub fn embed<'a>(
293    client: &dyn HttpClient,
294    api_url: &str,
295    api_key: &str,
296    model: OpenAiEmbeddingModel,
297    texts: impl IntoIterator<Item = &'a str>,
298) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
299    let uri = format!("{api_url}/embeddings");
300
301    let request = OpenAiEmbeddingRequest {
302        model,
303        input: texts.into_iter().collect(),
304    };
305    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
306    let request = HttpRequest::builder()
307        .method(Method::POST)
308        .uri(uri)
309        .header("Content-Type", "application/json")
310        .header("Authorization", format!("Bearer {}", api_key))
311        .body(body)
312        .map(|request| client.send(request));
313
314    async move {
315        let mut response = request?.await?;
316        let mut body = String::new();
317        response.body_mut().read_to_string(&mut body).await?;
318
319        if response.status().is_success() {
320            let response: OpenAiEmbeddingResponse =
321                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
322            Ok(response)
323        } else {
324            Err(anyhow!(
325                "error during embedding, status: {:?}, body: {:?}",
326                response.status(),
327                body
328            ))
329        }
330    }
331}