1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::{convert::TryFrom, future::Future, time::Duration};
8use strum::EnumIter;
9
/// Default base URL for the OpenAI REST API; callers may override it
/// via the `api_url` parameter of `stream_completion` / `embed`.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
11
/// Author of a chat message, serialized in the lowercase form the
/// OpenAI API expects (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
20
21impl TryFrom<String> for Role {
22 type Error = anyhow::Error;
23
24 fn try_from(value: String) -> Result<Self> {
25 match value.as_str() {
26 "user" => Ok(Self::User),
27 "assistant" => Ok(Self::Assistant),
28 "system" => Ok(Self::System),
29 "tool" => Ok(Self::Tool),
30 _ => Err(anyhow!("invalid role '{value}'")),
31 }
32 }
33}
34
35impl From<Role> for String {
36 fn from(val: Role) -> Self {
37 match val {
38 Role::User => "user".to_owned(),
39 Role::Assistant => "assistant".to_owned(),
40 Role::System => "system".to_owned(),
41 Role::Tool => "tool".to_owned(),
42 }
43 }
44}
45
/// Chat models this client knows about. The serde `rename` gives the
/// canonical API id; each `alias` accepts an older dated snapshot id
/// when deserializing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
}
59
60impl Model {
61 pub fn from_id(id: &str) -> Result<Self> {
62 match id {
63 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
64 "gpt-4" => Ok(Self::Four),
65 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
66 "gpt-4o" => Ok(Self::FourOmni),
67 _ => Err(anyhow!("invalid model id")),
68 }
69 }
70
71 pub fn id(&self) -> &'static str {
72 match self {
73 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
74 Self::Four => "gpt-4",
75 Self::FourTurbo => "gpt-4-turbo-preview",
76 Self::FourOmni => "gpt-4o",
77 }
78 }
79
80 pub fn display_name(&self) -> &'static str {
81 match self {
82 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
83 Self::Four => "gpt-4",
84 Self::FourTurbo => "gpt-4-turbo",
85 Self::FourOmni => "gpt-4o",
86 }
87 }
88
89 pub fn max_token_count(&self) -> usize {
90 match self {
91 Model::ThreePointFiveTurbo => 4096,
92 Model::Four => 8192,
93 Model::FourTurbo => 128000,
94 Model::FourOmni => 128000,
95 }
96 }
97}
98
/// Body of a `/chat/completions` request.
#[derive(Debug, Serialize)]
pub struct Request {
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    // When true the server replies with an SSE stream of deltas
    // (see `stream_completion`).
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
    // NOTE(review): string form only (e.g. "auto"/"none") — the API's
    // object form for forcing a specific tool is not representable here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
111
/// Declaration of a callable function advertised to the model,
/// with a JSON-schema-style `parameters` object.
#[derive(Debug, Serialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Map<String, Value>>,
}
118
/// A tool offered to the model. Serializes with a `"type"` tag;
/// currently only `"function"` tools exist.
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
125
/// One message in the conversation history, tagged by `"role"` on the
/// wire. Variant fields mirror what each role is allowed to carry.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Content may be absent when the assistant only issued tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        // Tool result, echoing the id of the call it answers.
        content: String,
        tool_call_id: String,
    },
}
145
/// A tool invocation requested by the assistant; `content` is flattened
/// so the `"type"`/`"function"` fields sit beside `"id"` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
152
/// Payload of a [`ToolCall`], tagged by `"type"`; only function calls
/// are currently defined.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
158
/// A complete function call: the function name plus its arguments as a
/// raw JSON string (the API does not parse it for us).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
164
/// Incremental message fragment from a streamed response; every field
/// may be absent because each SSE event carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCallChunk>,
}
172
/// Streamed fragment of a tool call; `index` identifies which call in
/// the message this chunk belongs to, so callers can accumulate chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
183
/// Streamed fragment of a function call; `arguments` arrives as partial
/// JSON text that must be concatenated across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
189
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
196
/// One choice's delta within a streamed event; `finish_reason` is set
/// only on the final chunk for that choice.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
203
/// A single server-sent event from the streaming chat endpoint.
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    // Unix timestamp (seconds) of event creation, per the API.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
211
212pub async fn stream_completion(
213 client: &dyn HttpClient,
214 api_url: &str,
215 api_key: &str,
216 request: Request,
217 low_speed_timeout: Option<Duration>,
218) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
219 let uri = format!("{api_url}/chat/completions");
220 let mut request_builder = HttpRequest::builder()
221 .method(Method::POST)
222 .uri(uri)
223 .header("Content-Type", "application/json")
224 .header("Authorization", format!("Bearer {}", api_key));
225
226 if let Some(low_speed_timeout) = low_speed_timeout {
227 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
228 };
229
230 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
231 let mut response = client.send(request).await?;
232 if response.status().is_success() {
233 let reader = BufReader::new(response.into_body());
234 Ok(reader
235 .lines()
236 .filter_map(|line| async move {
237 match line {
238 Ok(line) => {
239 let line = line.strip_prefix("data: ")?;
240 if line == "[DONE]" {
241 None
242 } else {
243 match serde_json::from_str(line) {
244 Ok(response) => Some(Ok(response)),
245 Err(error) => Some(Err(anyhow!(error))),
246 }
247 }
248 }
249 Err(error) => Some(Err(anyhow!(error))),
250 }
251 })
252 .boxed())
253 } else {
254 let mut body = String::new();
255 response.body_mut().read_to_string(&mut body).await?;
256
257 #[derive(Deserialize)]
258 struct OpenAiResponse {
259 error: OpenAiError,
260 }
261
262 #[derive(Deserialize)]
263 struct OpenAiError {
264 message: String,
265 }
266
267 match serde_json::from_str::<OpenAiResponse>(&body) {
268 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
269 "Failed to connect to OpenAI API: {}",
270 response.error.message,
271 )),
272
273 _ => Err(anyhow!(
274 "Failed to connect to OpenAI API: {} {}",
275 response.status(),
276 body,
277 )),
278 }
279 }
280}
281
/// Embedding models supported by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
289
/// Body of an `/embeddings` request; `input` borrows the caller's
/// texts so no copies are made before serialization.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
295
/// Response from the `/embeddings` endpoint: one entry per input text,
/// in request order.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
300
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
305
306pub fn embed<'a>(
307 client: &dyn HttpClient,
308 api_url: &str,
309 api_key: &str,
310 model: OpenAiEmbeddingModel,
311 texts: impl IntoIterator<Item = &'a str>,
312) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
313 let uri = format!("{api_url}/embeddings");
314
315 let request = OpenAiEmbeddingRequest {
316 model,
317 input: texts.into_iter().collect(),
318 };
319 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
320 let request = HttpRequest::builder()
321 .method(Method::POST)
322 .uri(uri)
323 .header("Content-Type", "application/json")
324 .header("Authorization", format!("Bearer {}", api_key))
325 .body(body)
326 .map(|request| client.send(request));
327
328 async move {
329 let mut response = request?.await?;
330 let mut body = String::new();
331 response.body_mut().read_to_string(&mut body).await?;
332
333 if response.status().is_success() {
334 let response: OpenAiEmbeddingResponse =
335 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
336 Ok(response)
337 } else {
338 Err(anyhow!(
339 "error during embedding, status: {:?}, body: {:?}",
340 response.status(),
341 body
342 ))
343 }
344 }
345}