1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::{convert::TryFrom, future::Future};
5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6
7pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
8
/// Chat participant role, serialized in OpenAI's lowercase wire format
/// ("user" / "assistant" / "system") via `rename_all`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
16
17impl TryFrom<String> for Role {
18 type Error = anyhow::Error;
19
20 fn try_from(value: String) -> Result<Self> {
21 match value.as_str() {
22 "user" => Ok(Self::User),
23 "assistant" => Ok(Self::Assistant),
24 "system" => Ok(Self::System),
25 _ => Err(anyhow!("invalid role '{value}'")),
26 }
27 }
28}
29
30impl From<Role> for String {
31 fn from(val: Role) -> Self {
32 match val {
33 Role::User => "user".to_owned(),
34 Role::Assistant => "assistant".to_owned(),
35 Role::System => "system".to_owned(),
36 }
37 }
38}
39
/// OpenAI chat model selector.
///
/// Serde `rename`s map each variant to its canonical API model id, and
/// `alias`es additionally accept pinned snapshot ids (e.g. "gpt-4-0613")
/// on deserialization. `FourTurbo` is the default.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    #[default]
    FourTurbo,
}
51
52impl Model {
53 pub fn from_id(id: &str) -> Result<Self> {
54 match id {
55 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
56 "gpt-4" => Ok(Self::Four),
57 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
58 _ => Err(anyhow!("invalid model id")),
59 }
60 }
61
62 pub fn id(&self) -> &'static str {
63 match self {
64 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
65 Self::Four => "gpt-4",
66 Self::FourTurbo => "gpt-4-turbo-preview",
67 }
68 }
69
70 pub fn display_name(&self) -> &'static str {
71 match self {
72 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
73 Self::Four => "gpt-4",
74 Self::FourTurbo => "gpt-4-turbo",
75 }
76 }
77
78 pub fn max_token_count(&self) -> usize {
79 match self {
80 Model::ThreePointFiveTurbo => 4096,
81 Model::Four => 8192,
82 Model::FourTurbo => 128000,
83 }
84 }
85}
86
/// Request body for the chat-completions endpoint.
#[derive(Debug, Serialize)]
pub struct Request {
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    /// When true the API streams the response as server-sent events
    /// (see `stream_completion`, which always expects a stream).
    pub stream: bool,
    /// Stop sequences; generation halts when any of these is produced.
    pub stop: Vec<String>,
    pub temperature: f32,
}
95
/// A single chat message sent to the model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    pub role: Role,
    pub content: String,
}
101
/// A (possibly partial) message received from the model. In streamed
/// deltas either field may be absent, hence both are `Option`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
107
/// Token accounting reported by the API for a request.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
114
/// One incremental choice within a streamed completion chunk.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Partial message content for this chunk.
    pub delta: ResponseMessage,
    /// `Some` on a choice's final chunk; `None` while streaming continues.
    /// The value is passed through untouched, not interpreted here.
    pub finish_reason: Option<String>,
}
121
/// One server-sent event from the streaming chat-completions endpoint.
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Only populated on some events — hence `Option`.
    pub usage: Option<Usage>,
}
129
130pub async fn stream_completion(
131 client: &dyn HttpClient,
132 api_url: &str,
133 api_key: &str,
134 request: Request,
135) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
136 let uri = format!("{api_url}/chat/completions");
137 let request = HttpRequest::builder()
138 .method(Method::POST)
139 .uri(uri)
140 .header("Content-Type", "application/json")
141 .header("Authorization", format!("Bearer {}", api_key))
142 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
143 let mut response = client.send(request).await?;
144 if response.status().is_success() {
145 let reader = BufReader::new(response.into_body());
146 Ok(reader
147 .lines()
148 .filter_map(|line| async move {
149 match line {
150 Ok(line) => {
151 let line = line.strip_prefix("data: ")?;
152 if line == "[DONE]" {
153 None
154 } else {
155 match serde_json::from_str(line) {
156 Ok(response) => Some(Ok(response)),
157 Err(error) => Some(Err(anyhow!(error))),
158 }
159 }
160 }
161 Err(error) => Some(Err(anyhow!(error))),
162 }
163 })
164 .boxed())
165 } else {
166 let mut body = String::new();
167 response.body_mut().read_to_string(&mut body).await?;
168
169 #[derive(Deserialize)]
170 struct OpenAiResponse {
171 error: OpenAiError,
172 }
173
174 #[derive(Deserialize)]
175 struct OpenAiError {
176 message: String,
177 }
178
179 match serde_json::from_str::<OpenAiResponse>(&body) {
180 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
181 "Failed to connect to OpenAI API: {}",
182 response.error.message,
183 )),
184
185 _ => Err(anyhow!(
186 "Failed to connect to OpenAI API: {} {}",
187 response.status(),
188 body,
189 )),
190 }
191 }
192}
193
/// Embedding model selector; serde `rename`s map each variant to its
/// API model id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
201
/// Request body for the embeddings endpoint; borrows the input texts
/// so no copies are made before serialization.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
207
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
212
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
217
218pub fn embed<'a>(
219 client: &dyn HttpClient,
220 api_url: &str,
221 api_key: &str,
222 model: OpenAiEmbeddingModel,
223 texts: impl IntoIterator<Item = &'a str>,
224) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
225 let uri = format!("{api_url}/embeddings");
226
227 let request = OpenAiEmbeddingRequest {
228 model,
229 input: texts.into_iter().collect(),
230 };
231 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
232 let request = HttpRequest::builder()
233 .method(Method::POST)
234 .uri(uri)
235 .header("Content-Type", "application/json")
236 .header("Authorization", format!("Bearer {}", api_key))
237 .body(body)
238 .map(|request| client.send(request));
239
240 async move {
241 let mut response = request?.await?;
242 let mut body = String::new();
243 response.body_mut().read_to_string(&mut body).await?;
244
245 if response.status().is_success() {
246 let response: OpenAiEmbeddingResponse =
247 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
248 Ok(response)
249 } else {
250 Err(anyhow!(
251 "error during embedding, status: {:?}, body: {:?}",
252 response.status(),
253 body
254 ))
255 }
256 }
257}