open_ai.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use isahc::config::Configurable;
  4use serde::{Deserialize, Serialize};
  5use serde_json::{Map, Value};
  6use std::time::Duration;
  7use std::{convert::TryFrom, future::Future};
  8use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  9
 10pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 11
/// The author role of a chat message, as understood by the OpenAI
/// chat-completions API.
///
/// Serialized/deserialized in lowercase ("user", "assistant", "system",
/// "tool") to match the wire format; see also the `TryFrom<String>` /
/// `From<Role>` impls below, which use the same spellings.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 20
 21impl TryFrom<String> for Role {
 22    type Error = anyhow::Error;
 23
 24    fn try_from(value: String) -> Result<Self> {
 25        match value.as_str() {
 26            "user" => Ok(Self::User),
 27            "assistant" => Ok(Self::Assistant),
 28            "system" => Ok(Self::System),
 29            "tool" => Ok(Self::Tool),
 30            _ => Err(anyhow!("invalid role '{value}'")),
 31        }
 32    }
 33}
 34
 35impl From<Role> for String {
 36    fn from(val: Role) -> Self {
 37        match val {
 38            Role::User => "user".to_owned(),
 39            Role::Assistant => "assistant".to_owned(),
 40            Role::System => "system".to_owned(),
 41            Role::Tool => "tool".to_owned(),
 42        }
 43    }
 44}
 45
/// The OpenAI chat models this client knows about.
///
/// Serde maps each variant to its canonical API id and additionally accepts
/// the dated snapshot alias (e.g. "gpt-4-0613"). The default model is
/// `FourTurbo`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    #[default]
    FourTurbo,
}
 57
 58impl Model {
 59    pub fn from_id(id: &str) -> Result<Self> {
 60        match id {
 61            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 62            "gpt-4" => Ok(Self::Four),
 63            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 64            _ => Err(anyhow!("invalid model id")),
 65        }
 66    }
 67
 68    pub fn id(&self) -> &'static str {
 69        match self {
 70            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 71            Self::Four => "gpt-4",
 72            Self::FourTurbo => "gpt-4-turbo-preview",
 73        }
 74    }
 75
 76    pub fn display_name(&self) -> &'static str {
 77        match self {
 78            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 79            Self::Four => "gpt-4",
 80            Self::FourTurbo => "gpt-4-turbo",
 81        }
 82    }
 83
 84    pub fn max_token_count(&self) -> usize {
 85        match self {
 86            Model::ThreePointFiveTurbo => 4096,
 87            Model::Four => 8192,
 88            Model::FourTurbo => 128000,
 89        }
 90    }
 91}
 92
/// Request body for the chat-completions endpoint.
#[derive(Debug, Serialize)]
pub struct Request {
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    // `true` requests a server-sent-event stream (see `stream_completion`).
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
    // Omitted from the JSON body when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    // Omitted from the JSON body when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
105
/// Declaration of a callable function offered to the model as a tool.
#[derive(Debug, Serialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON-schema-shaped parameter description as a raw JSON object;
    // serialized verbatim.
    pub parameters: Option<Map<String, Value>>,
}
112
/// A tool made available to the model. Serialized with an internal
/// `"type"` tag (currently only `"function"`).
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
119
/// A single message in the conversation sent to the API.
///
/// Serialized with a lowercase `"role"` tag, matching [`Role`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // An assistant turn may have no text when it only issued tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        // Correlates this tool result with the assistant's originating call.
        tool_call_id: String,
    },
}
139
/// A tool invocation issued by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so the `"type"`/payload fields sit alongside `id` in JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
146
/// Payload of a [`ToolCall`], tagged by `"type"` (currently only
/// `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
152
/// A concrete function invocation: the function name plus its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments kept as a raw string (presumably JSON-encoded by the API —
    // not parsed here).
    pub arguments: String,
}
158
/// Incremental message fragment from a streaming response; every field may
/// be absent on any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCallChunk>,
}
166
/// A streamed fragment of a tool call. Fragments sharing the same `index`
/// belong to the same logical call and must be accumulated by the caller.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
177
/// A streamed fragment of a function call; `name` and `arguments` arrive
/// piecewise across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
183
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
190
/// One choice's delta within a streaming event.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // Present on the final chunk for a choice (e.g. why generation stopped).
    pub finish_reason: Option<String>,
}
197
/// A single server-sent event from the streaming chat-completions endpoint.
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    // NOTE(review): presumably a Unix timestamp — confirm against the API;
    // u32 would overflow in 2106.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
205
/// Sends a chat-completions request and returns a stream of parsed
/// server-sent events.
///
/// POSTs `request` to `{api_url}/chat/completions` with bearer-token auth.
/// On a success status, the response body is read line by line: each
/// `data: `-prefixed line is JSON-decoded into a [`ResponseStreamEvent`];
/// the `[DONE]` sentinel ends the stream; lines without the `data: ` prefix
/// are silently skipped. On a non-success status, the body is read and, if
/// it parses as an OpenAI error envelope, its message is surfaced; otherwise
/// the raw status and body are reported.
///
/// `low_speed_timeout` is forwarded to the HTTP client as a minimum
/// transfer-rate guard (100 bytes/s over the given duration).
///
/// # Errors
/// Fails if the request cannot be built or sent, or if the server responds
/// with a non-success status. Individual malformed event lines become `Err`
/// items in the returned stream rather than failing the whole call.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE payload lines are prefixed with "data: ";
                        // anything else (blank keep-alives, comments) is
                        // dropped by the `?` returning `None`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: terminate the stream.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(response) => Some(Ok(response)),
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Local mirror of OpenAI's error envelope: {"error": {"message": …}}.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            // Prefer the structured API error message when one is present.
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            // Fall back to the raw status and body.
            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
275
/// Embedding models supported by this client, serialized as their API ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
283
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
289
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
294
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
299
300pub fn embed<'a>(
301    client: &dyn HttpClient,
302    api_url: &str,
303    api_key: &str,
304    model: OpenAiEmbeddingModel,
305    texts: impl IntoIterator<Item = &'a str>,
306) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
307    let uri = format!("{api_url}/embeddings");
308
309    let request = OpenAiEmbeddingRequest {
310        model,
311        input: texts.into_iter().collect(),
312    };
313    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
314    let request = HttpRequest::builder()
315        .method(Method::POST)
316        .uri(uri)
317        .header("Content-Type", "application/json")
318        .header("Authorization", format!("Bearer {}", api_key))
319        .body(body)
320        .map(|request| client.send(request));
321
322    async move {
323        let mut response = request?.await?;
324        let mut body = String::new();
325        response.body_mut().read_to_string(&mut body).await?;
326
327        if response.status().is_success() {
328            let response: OpenAiEmbeddingResponse =
329                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
330            Ok(response)
331        } else {
332            Err(anyhow!(
333                "error during embedding, status: {:?}, body: {:?}",
334                response.status(),
335                body
336            ))
337        }
338    }
339}