// open_ai.rs

  1use anyhow::{anyhow, Context, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use serde::{Deserialize, Serialize};
  4use std::{convert::TryFrom, future::Future};
  5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6
  7pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
  8
/// The author of a chat message, serialized as the lowercase wire names
/// ("user" / "assistant" / "system") expected by the OpenAI API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 16
 17impl TryFrom<String> for Role {
 18    type Error = anyhow::Error;
 19
 20    fn try_from(value: String) -> Result<Self> {
 21        match value.as_str() {
 22            "user" => Ok(Self::User),
 23            "assistant" => Ok(Self::Assistant),
 24            "system" => Ok(Self::System),
 25            _ => Err(anyhow!("invalid role '{value}'")),
 26        }
 27    }
 28}
 29
 30impl From<Role> for String {
 31    fn from(val: Role) -> Self {
 32        match val {
 33            Role::User => "user".to_owned(),
 34            Role::Assistant => "assistant".to_owned(),
 35            Role::System => "system".to_owned(),
 36        }
 37    }
 38}
 39
/// The OpenAI chat model to request. Serializes to the model id strings
/// sent over the wire; `alias`es accept older snapshot ids on deserialize.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    // Default model when none is configured.
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    #[default]
    FourTurbo,
}
 51
 52impl Model {
 53    pub fn from_id(id: &str) -> Result<Self> {
 54        match id {
 55            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 56            "gpt-4" => Ok(Self::Four),
 57            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 58            _ => Err(anyhow!("invalid model id")),
 59        }
 60    }
 61
 62    pub fn id(&self) -> &'static str {
 63        match self {
 64            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 65            Self::Four => "gpt-4",
 66            Self::FourTurbo => "gpt-4-turbo-preview",
 67        }
 68    }
 69
 70    pub fn display_name(&self) -> &'static str {
 71        match self {
 72            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
 73            Self::Four => "gpt-4",
 74            Self::FourTurbo => "gpt-4-turbo",
 75        }
 76    }
 77
 78    pub fn max_token_count(&self) -> usize {
 79        match self {
 80            Model::ThreePointFiveTurbo => 4096,
 81            Model::Four => 8192,
 82            Model::FourTurbo => 128000,
 83        }
 84    }
 85}
 86
/// Body of a chat-completions request, serialized as JSON and POSTed to
/// `/chat/completions` by [`stream_completion`].
#[derive(Debug, Serialize)]
pub struct Request {
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    // When true the server replies with an SSE stream of deltas.
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
}
 95
/// A single chat message sent to the API: who said it and what was said.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}
101
/// A (partial) message delta from a streamed response; either field may be
/// absent in any given chunk, so both are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
107
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
114
/// One choice's incremental update within a stream event.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    // Set on the final delta of a choice (e.g. "stop"); None mid-stream.
    pub finish_reason: Option<String>,
}
121
/// One parsed SSE event from a streaming chat-completions response.
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
129
/// Sends a streaming chat-completion request to `{api_url}/chat/completions`
/// and returns a stream of [`ResponseStreamEvent`]s parsed from the
/// server-sent-events (SSE) response body.
///
/// On a non-success HTTP status the whole body is read and an error is
/// returned containing OpenAI's error message when one can be parsed, or
/// the raw status and body otherwise.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines carry a "data: " prefix; any other
                        // line (e.g. blank keep-alives) is dropped by the `?`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: emit nothing for it.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(response) => Some(Ok(response)),
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    // I/O errors while reading the body surface as stream items.
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Minimal shape of OpenAI's error payload: {"error": {"message": ...}}.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            // Prefer the server's own message when it parsed and is non-empty.
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            // Fall back to the raw status and body.
            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
193
/// Embedding model to request, serialized as OpenAI's model id strings.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
201
/// JSON body POSTed to `/embeddings` by [`embed`]; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
207
/// Parsed `/embeddings` response: one [`OpenAiEmbedding`] per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
212
/// A single embedding vector returned by the API.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
217
218pub fn embed<'a>(
219    client: &dyn HttpClient,
220    api_url: &str,
221    api_key: &str,
222    model: OpenAiEmbeddingModel,
223    texts: impl IntoIterator<Item = &'a str>,
224) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
225    let uri = format!("{api_url}/embeddings");
226
227    let request = OpenAiEmbeddingRequest {
228        model,
229        input: texts.into_iter().collect(),
230    };
231    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
232    let request = HttpRequest::builder()
233        .method(Method::POST)
234        .uri(uri)
235        .header("Content-Type", "application/json")
236        .header("Authorization", format!("Bearer {}", api_key))
237        .body(body)
238        .map(|request| client.send(request));
239
240    async move {
241        let mut response = request?.await?;
242        let mut body = String::new();
243        response.body_mut().read_to_string(&mut body).await?;
244
245        if response.status().is_success() {
246            let response: OpenAiEmbeddingResponse =
247                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
248            Ok(response)
249        } else {
250            Err(anyhow!(
251                "error during embedding, status: {:?}, body: {:?}",
252                response.status(),
253                body
254            ))
255        }
256    }
257}