1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::{convert::TryFrom, future::Future, time::Duration};
8use strum::EnumIter;
9
10pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
11
/// A chat participant role, serialized in lowercase ("user", "assistant",
/// "system", "tool") to match the OpenAI chat-completions wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
20
21impl TryFrom<String> for Role {
22 type Error = anyhow::Error;
23
24 fn try_from(value: String) -> Result<Self> {
25 match value.as_str() {
26 "user" => Ok(Self::User),
27 "assistant" => Ok(Self::Assistant),
28 "system" => Ok(Self::System),
29 "tool" => Ok(Self::Tool),
30 _ => Err(anyhow!("invalid role '{value}'")),
31 }
32 }
33}
34
35impl From<Role> for String {
36 fn from(val: Role) -> Self {
37 match val {
38 Role::User => "user".to_owned(),
39 Role::Assistant => "assistant".to_owned(),
40 Role::System => "system".to_owned(),
41 Role::Tool => "tool".to_owned(),
42 }
43 }
44}
45
/// The OpenAI chat models this crate knows about.
///
/// Each named variant serializes to its canonical id (the `rename` string);
/// the `alias` strings accept older dated snapshot ids on deserialization.
/// `Custom` is the escape hatch for arbitrary model names with a
/// caller-supplied context window size.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    FourTurbo,
    /// The default model.
    #[serde(rename = "gpt-4o", alias = "gpt-4o-2024-05-13")]
    #[default]
    FourOmni,
    /// A user-specified model name and its maximum token count.
    #[serde(rename = "custom")]
    Custom { name: String, max_tokens: usize },
}
61
62impl Model {
63 pub fn from_id(id: &str) -> Result<Self> {
64 match id {
65 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
66 "gpt-4" => Ok(Self::Four),
67 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
68 "gpt-4o" => Ok(Self::FourOmni),
69 _ => Err(anyhow!("invalid model id")),
70 }
71 }
72
73 pub fn id(&self) -> &'static str {
74 match self {
75 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
76 Self::Four => "gpt-4",
77 Self::FourTurbo => "gpt-4-turbo-preview",
78 Self::FourOmni => "gpt-4o",
79 Self::Custom { .. } => "custom",
80 }
81 }
82
83 pub fn display_name(&self) -> &str {
84 match self {
85 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
86 Self::Four => "gpt-4",
87 Self::FourTurbo => "gpt-4-turbo",
88 Self::FourOmni => "gpt-4o",
89 Self::Custom { name, .. } => name,
90 }
91 }
92
93 pub fn max_token_count(&self) -> usize {
94 match self {
95 Model::ThreePointFiveTurbo => 4096,
96 Model::Four => 8192,
97 Model::FourTurbo => 128000,
98 Model::FourOmni => 128000,
99 Model::Custom { max_tokens, .. } => *max_tokens,
100 }
101 }
102}
103
104fn serialize_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
105where
106 S: serde::Serializer,
107{
108 match model {
109 Model::Custom { name, .. } => serializer.serialize_str(name),
110 _ => serializer.serialize_str(model.id()),
111 }
112}
113
/// A chat-completions request body.
#[derive(Debug, Serialize)]
pub struct Request {
    // Serialized via `serialize_model` so `Custom` sends its name, not "custom".
    #[serde(serialize_with = "serialize_model")]
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    // `stream_completion` expects this to be true so the server emits SSE.
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
    // Omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    // Omitted from the JSON when no tools are offered.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
127
/// A callable function offered to the model as part of a tool definition.
#[derive(Debug, Serialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // Presumably a JSON Schema object describing the arguments — TODO confirm
    // against the OpenAI tools documentation.
    pub parameters: Option<Map<String, Value>>,
}
134
/// A tool the model may invoke; serialized with a `"type"` tag
/// (currently only `"function"`).
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
141
/// One message in a chat-completions conversation, tagged by its `"role"`
/// field on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Assistant turns may carry tool calls instead of (or alongside) text.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        // A tool result, correlated to the assistant's request by call id.
        content: String,
        tool_call_id: String,
    },
}
161
/// A tool invocation requested by the assistant; `id` is echoed back in the
/// corresponding `RequestMessage::Tool` reply.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so the `type`/`function` fields sit beside `id` in the JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
168
/// The payload of a tool call, tagged by `"type"` (currently only
/// `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
174
/// A function invocation: the function's name and its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Kept as a raw string, not parsed JSON — presumably a JSON-encoded
    // argument object per the OpenAI API; callers parse it themselves.
    pub arguments: String,
}
180
/// An incremental message update within a streamed response; each field is
/// optional/partial because the server sends the message in fragments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCallChunk>,
}
188
/// A fragment of a tool call within a streamed delta; `index` identifies
/// which in-progress tool call the fragment belongs to.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
199
/// A fragment of a function invocation; name and arguments stream in
/// independently, so both are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
205
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
212
/// One choice's delta within a streamed event; `finish_reason` is set on the
/// final chunk of a choice.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
219
/// One server-sent event in a streamed chat-completions response.
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
227
/// Issues a streaming chat-completions request and returns a stream of
/// decoded events, one [`ResponseStreamEvent`] per SSE `data:` line.
///
/// On a non-success HTTP status the body is read in full and surfaced as an
/// error, preferring the `error.message` field of the OpenAI error payload
/// when the body parses as one.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    // isahc low-speed abort: give up if throughput stays below 100 bytes/s
    // for the given duration.
    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE framing: only `data: `-prefixed lines carry
                        // payloads; anything else is dropped via the `?`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // Terminal sentinel — end the stream.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(response) => Some(Ok(response)),
                                // Deserialization failures are surfaced as
                                // stream items rather than ending the stream.
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Minimal mirror of OpenAI's error envelope: `{"error": {"message": …}}`.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            // Prefer the API's own message when it is present and non-empty.
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            // Otherwise fall back to the raw status and body.
            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
297
/// The OpenAI embedding models supported by [`embed`], serialized as their
/// wire ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
305
/// The request body for the embeddings endpoint; borrows the input texts to
/// avoid copying them into the request struct.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
311
/// The embeddings endpoint's response: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
316
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
321
/// Requests embeddings for `texts` from the OpenAI embeddings endpoint.
///
/// The HTTP request is built and dispatched *before* the returned future is
/// awaited — that is what lets the future be `'static` even though `client`,
/// `api_url`, `api_key`, and `texts` are only borrowed here. Restructure with
/// care: moving the send into the `async move` block would capture the
/// non-`'static` borrows.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // unwrap: serializing a struct of plain strings and a unit-variant enum
    // is presumably infallible — NOTE(review): confirm, or propagate instead.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    // `map` keeps the builder's Result so the error surfaces via `?` below.
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}