1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use isahc::config::Configurable;
4use serde::{Deserialize, Serialize};
5use serde_json::{Map, Value};
6use std::time::Duration;
7use std::{convert::TryFrom, future::Future};
8use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
9
/// Default base URL for the OpenAI REST API; endpoint paths such as
/// `/chat/completions` and `/embeddings` are appended to it.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
11
/// The author of a chat message. Serialized in lowercase ("user", "assistant",
/// "system", "tool") to match the OpenAI chat API wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
20
21impl TryFrom<String> for Role {
22 type Error = anyhow::Error;
23
24 fn try_from(value: String) -> Result<Self> {
25 match value.as_str() {
26 "user" => Ok(Self::User),
27 "assistant" => Ok(Self::Assistant),
28 "system" => Ok(Self::System),
29 "tool" => Ok(Self::Tool),
30 _ => Err(anyhow!("invalid role '{value}'")),
31 }
32 }
33}
34
35impl From<Role> for String {
36 fn from(val: Role) -> Self {
37 match val {
38 Role::User => "user".to_owned(),
39 Role::Assistant => "assistant".to_owned(),
40 Role::System => "system".to_owned(),
41 Role::Tool => "tool".to_owned(),
42 }
43 }
44}
45
/// The OpenAI chat models this module knows about. The serde representation is
/// the canonical model id; `alias` additionally accepts dated snapshot ids
/// (e.g. "gpt-4-0613") on deserialization.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
    Four,
    /// The default model.
    #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
    #[default]
    FourTurbo,
}
57
58impl Model {
59 pub fn from_id(id: &str) -> Result<Self> {
60 match id {
61 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
62 "gpt-4" => Ok(Self::Four),
63 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
64 _ => Err(anyhow!("invalid model id")),
65 }
66 }
67
68 pub fn id(&self) -> &'static str {
69 match self {
70 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
71 Self::Four => "gpt-4",
72 Self::FourTurbo => "gpt-4-turbo-preview",
73 }
74 }
75
76 pub fn display_name(&self) -> &'static str {
77 match self {
78 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
79 Self::Four => "gpt-4",
80 Self::FourTurbo => "gpt-4-turbo",
81 }
82 }
83
84 pub fn max_token_count(&self) -> usize {
85 match self {
86 Model::ThreePointFiveTurbo => 4096,
87 Model::Four => 8192,
88 Model::FourTurbo => 128000,
89 }
90 }
91}
92
/// Body of a `/chat/completions` request (see `stream_completion`).
#[derive(Debug, Serialize)]
pub struct Request {
    pub model: Model,
    pub messages: Vec<RequestMessage>,
    // `stream_completion` parses SSE "data:" lines, so callers using it should
    // set this to `true`.
    pub stream: bool,
    pub stop: Vec<String>,
    pub temperature: f32,
    // Omitted from the JSON when `None` / empty, so requests without tools
    // look identical to ones built before tool support existed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
105
/// Declaration of a callable function offered to the model.
#[derive(Debug, Serialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // Arbitrary JSON object; presumably a JSON Schema describing the
    // function's arguments — not validated here.
    pub parameters: Option<Map<String, Value>>,
}
112
/// A tool made available to the model. Serializes with a `type` tag, so the
/// single variant becomes `{"type": "function", "function": {...}}`.
#[derive(Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
119
/// A chat message in the request payload. Serde tags each variant with a
/// lowercase `role` field ("assistant", "user", "system", "tool").
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // `content` may be absent when the assistant responded purely with
        // tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        // The result of a tool invocation; `tool_call_id` links it back to the
        // `ToolCall::id` that requested it.
        content: String,
        tool_call_id: String,
    },
}
139
/// A tool invocation requested by the assistant. `content` is flattened, so
/// its `type`/`function` fields appear inline next to `id` in the JSON.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
146
/// Payload of a tool call, discriminated by the lowercase `type` tag
/// (currently only "function").
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
152
/// A concrete function invocation: the function's name plus its arguments.
/// `arguments` is kept as the raw string the model produced (typically JSON),
/// not parsed here.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
158
/// An incremental message fragment from a streaming response; every field is
/// optional/partial and must be accumulated by the caller across events.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCallChunk>,
}
166
/// A partial tool call inside a streamed delta. `index` identifies which tool
/// call in the message this chunk belongs to, so chunks can be merged.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`, so we deliberately don't model `type` here.
    pub function: Option<FunctionChunk>,
}
177
/// A partial function invocation: either field may be absent in any given
/// chunk; `arguments` fragments concatenate across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
183
/// Token accounting reported by the API for a completed request.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
190
/// One choice within a stream event: its position, the message fragment, and
/// (once the choice ends) the reason the model stopped.
#[derive(Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
197
/// One parsed server-sent event from the `/chat/completions` stream
/// (the JSON after a `data: ` prefix; see `stream_completion`).
#[derive(Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
205
206pub async fn stream_completion(
207 client: &dyn HttpClient,
208 api_url: &str,
209 api_key: &str,
210 request: Request,
211 low_speed_timeout: Option<Duration>,
212) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
213 let uri = format!("{api_url}/chat/completions");
214 let mut request_builder = HttpRequest::builder()
215 .method(Method::POST)
216 .uri(uri)
217 .header("Content-Type", "application/json")
218 .header("Authorization", format!("Bearer {}", api_key));
219
220 if let Some(low_speed_timeout) = low_speed_timeout {
221 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
222 };
223
224 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
225 let mut response = client.send(request).await?;
226 if response.status().is_success() {
227 let reader = BufReader::new(response.into_body());
228 Ok(reader
229 .lines()
230 .filter_map(|line| async move {
231 match line {
232 Ok(line) => {
233 let line = line.strip_prefix("data: ")?;
234 if line == "[DONE]" {
235 None
236 } else {
237 match serde_json::from_str(line) {
238 Ok(response) => Some(Ok(response)),
239 Err(error) => Some(Err(anyhow!(error))),
240 }
241 }
242 }
243 Err(error) => Some(Err(anyhow!(error))),
244 }
245 })
246 .boxed())
247 } else {
248 let mut body = String::new();
249 response.body_mut().read_to_string(&mut body).await?;
250
251 #[derive(Deserialize)]
252 struct OpenAiResponse {
253 error: OpenAiError,
254 }
255
256 #[derive(Deserialize)]
257 struct OpenAiError {
258 message: String,
259 }
260
261 match serde_json::from_str::<OpenAiResponse>(&body) {
262 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
263 "Failed to connect to OpenAI API: {}",
264 response.error.message,
265 )),
266
267 _ => Err(anyhow!(
268 "Failed to connect to OpenAI API: {} {}",
269 response.status(),
270 body,
271 )),
272 }
273 }
274}
275
/// Embedding models accepted by the `/embeddings` endpoint; serialized as the
/// official model ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
283
/// Body of an `/embeddings` request: the model plus the batch of input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
289
/// Response from `/embeddings`: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
294
/// A single embedding vector from the response.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
299
300pub fn embed<'a>(
301 client: &dyn HttpClient,
302 api_url: &str,
303 api_key: &str,
304 model: OpenAiEmbeddingModel,
305 texts: impl IntoIterator<Item = &'a str>,
306) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
307 let uri = format!("{api_url}/embeddings");
308
309 let request = OpenAiEmbeddingRequest {
310 model,
311 input: texts.into_iter().collect(),
312 };
313 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
314 let request = HttpRequest::builder()
315 .method(Method::POST)
316 .uri(uri)
317 .header("Content-Type", "application/json")
318 .header("Authorization", format!("Bearer {}", api_key))
319 .body(body)
320 .map(|request| client.send(request));
321
322 async move {
323 let mut response = request?.await?;
324 let mut body = String::new();
325 response.body_mut().read_to_string(&mut body).await?;
326
327 if response.status().is_success() {
328 let response: OpenAiEmbeddingResponse =
329 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
330 Ok(response)
331 } else {
332 Err(anyhow!(
333 "error during embedding, status: {:?}, body: {:?}",
334 response.status(),
335 body
336 ))
337 }
338 }
339}