1use anyhow::{anyhow, Context, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value};
5use std::{convert::TryFrom, future::Future};
6use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
7
/// Base URL of the OpenAI REST API (v1).
///
/// The request helpers in this module take a base URL and append endpoint
/// paths such as `/chat/completions` and `/embeddings` to it.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
9
10#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13 User,
14 Assistant,
15 System,
16 Tool,
17}
18
19impl TryFrom<String> for Role {
20 type Error = anyhow::Error;
21
22 fn try_from(value: String) -> Result<Self> {
23 match value.as_str() {
24 "user" => Ok(Self::User),
25 "assistant" => Ok(Self::Assistant),
26 "system" => Ok(Self::System),
27 "tool" => Ok(Self::Tool),
28 _ => Err(anyhow!("invalid role '{value}'")),
29 }
30 }
31}
32
33impl From<Role> for String {
34 fn from(val: Role) -> Self {
35 match val {
36 Role::User => "user".to_owned(),
37 Role::Assistant => "assistant".to_owned(),
38 Role::System => "system".to_owned(),
39 Role::Tool => "tool".to_owned(),
40 }
41 }
42}
43
44#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
45#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
46pub enum Model {
47 #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
48 ThreePointFiveTurbo,
49 #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
50 Four,
51 #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
52 #[default]
53 FourTurbo,
54}
55
56impl Model {
57 pub fn from_id(id: &str) -> Result<Self> {
58 match id {
59 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
60 "gpt-4" => Ok(Self::Four),
61 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
62 _ => Err(anyhow!("invalid model id")),
63 }
64 }
65
66 pub fn id(&self) -> &'static str {
67 match self {
68 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
69 Self::Four => "gpt-4",
70 Self::FourTurbo => "gpt-4-turbo-preview",
71 }
72 }
73
74 pub fn display_name(&self) -> &'static str {
75 match self {
76 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
77 Self::Four => "gpt-4",
78 Self::FourTurbo => "gpt-4-turbo",
79 }
80 }
81
82 pub fn max_token_count(&self) -> usize {
83 match self {
84 Model::ThreePointFiveTurbo => 4096,
85 Model::Four => 8192,
86 Model::FourTurbo => 128000,
87 }
88 }
89}
90
91#[derive(Debug, Serialize)]
92pub struct Request {
93 pub model: Model,
94 pub messages: Vec<RequestMessage>,
95 pub stream: bool,
96 pub stop: Vec<String>,
97 pub temperature: f32,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 pub tool_choice: Option<String>,
100 #[serde(skip_serializing_if = "Vec::is_empty")]
101 pub tools: Vec<ToolDefinition>,
102}
103
104#[derive(Debug, Serialize)]
105pub struct FunctionDefinition {
106 pub name: String,
107 pub description: Option<String>,
108 pub parameters: Option<Map<String, Value>>,
109}
110
111#[derive(Serialize, Debug)]
112#[serde(tag = "type", rename_all = "snake_case")]
113pub enum ToolDefinition {
114 #[allow(dead_code)]
115 Function { function: FunctionDefinition },
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119#[serde(tag = "role", rename_all = "lowercase")]
120pub enum RequestMessage {
121 Assistant {
122 content: Option<String>,
123 #[serde(default, skip_serializing_if = "Vec::is_empty")]
124 tool_calls: Vec<ToolCall>,
125 },
126 User {
127 content: String,
128 },
129 System {
130 content: String,
131 },
132 Tool {
133 content: String,
134 tool_call_id: String,
135 },
136}
137
138#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
139pub struct ToolCall {
140 pub id: String,
141 #[serde(flatten)]
142 pub content: ToolCallContent,
143}
144
145#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
146#[serde(tag = "type", rename_all = "lowercase")]
147pub enum ToolCallContent {
148 Function { function: FunctionContent },
149}
150
151#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
152pub struct FunctionContent {
153 pub name: String,
154 pub arguments: String,
155}
156
157#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
158pub struct ResponseMessageDelta {
159 pub role: Option<Role>,
160 pub content: Option<String>,
161 #[serde(default, skip_serializing_if = "Vec::is_empty")]
162 pub tool_calls: Vec<ToolCallChunk>,
163}
164
165#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
166pub struct ToolCallChunk {
167 pub index: usize,
168 pub id: Option<String>,
169
170 // There is also an optional `type` field that would determine if a
171 // function is there. Sometimes this streams in with the `function` before
172 // it streams in the `type`
173 pub function: Option<FunctionChunk>,
174}
175
176#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
177pub struct FunctionChunk {
178 pub name: Option<String>,
179 pub arguments: Option<String>,
180}
181
182#[derive(Deserialize, Debug)]
183pub struct Usage {
184 pub prompt_tokens: u32,
185 pub completion_tokens: u32,
186 pub total_tokens: u32,
187}
188
189#[derive(Deserialize, Debug)]
190pub struct ChoiceDelta {
191 pub index: u32,
192 pub delta: ResponseMessageDelta,
193 pub finish_reason: Option<String>,
194}
195
196#[derive(Deserialize, Debug)]
197pub struct ResponseStreamEvent {
198 pub created: u32,
199 pub model: String,
200 pub choices: Vec<ChoiceDelta>,
201 pub usage: Option<Usage>,
202}
203
204pub async fn stream_completion(
205 client: &dyn HttpClient,
206 api_url: &str,
207 api_key: &str,
208 request: Request,
209) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
210 let uri = format!("{api_url}/chat/completions");
211 let request = HttpRequest::builder()
212 .method(Method::POST)
213 .uri(uri)
214 .header("Content-Type", "application/json")
215 .header("Authorization", format!("Bearer {}", api_key))
216 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
217 let mut response = client.send(request).await?;
218 if response.status().is_success() {
219 let reader = BufReader::new(response.into_body());
220 Ok(reader
221 .lines()
222 .filter_map(|line| async move {
223 match line {
224 Ok(line) => {
225 let line = line.strip_prefix("data: ")?;
226 if line == "[DONE]" {
227 None
228 } else {
229 match serde_json::from_str(line) {
230 Ok(response) => Some(Ok(response)),
231 Err(error) => Some(Err(anyhow!(error))),
232 }
233 }
234 }
235 Err(error) => Some(Err(anyhow!(error))),
236 }
237 })
238 .boxed())
239 } else {
240 let mut body = String::new();
241 response.body_mut().read_to_string(&mut body).await?;
242
243 #[derive(Deserialize)]
244 struct OpenAiResponse {
245 error: OpenAiError,
246 }
247
248 #[derive(Deserialize)]
249 struct OpenAiError {
250 message: String,
251 }
252
253 match serde_json::from_str::<OpenAiResponse>(&body) {
254 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
255 "Failed to connect to OpenAI API: {}",
256 response.error.message,
257 )),
258
259 _ => Err(anyhow!(
260 "Failed to connect to OpenAI API: {} {}",
261 response.status(),
262 body,
263 )),
264 }
265 }
266}
267
268#[derive(Copy, Clone, Serialize, Deserialize)]
269pub enum OpenAiEmbeddingModel {
270 #[serde(rename = "text-embedding-3-small")]
271 TextEmbedding3Small,
272 #[serde(rename = "text-embedding-3-large")]
273 TextEmbedding3Large,
274}
275
276#[derive(Serialize)]
277struct OpenAiEmbeddingRequest<'a> {
278 model: OpenAiEmbeddingModel,
279 input: Vec<&'a str>,
280}
281
282#[derive(Deserialize)]
283pub struct OpenAiEmbeddingResponse {
284 pub data: Vec<OpenAiEmbedding>,
285}
286
287#[derive(Deserialize)]
288pub struct OpenAiEmbedding {
289 pub embedding: Vec<f32>,
290}
291
292pub fn embed<'a>(
293 client: &dyn HttpClient,
294 api_url: &str,
295 api_key: &str,
296 model: OpenAiEmbeddingModel,
297 texts: impl IntoIterator<Item = &'a str>,
298) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
299 let uri = format!("{api_url}/embeddings");
300
301 let request = OpenAiEmbeddingRequest {
302 model,
303 input: texts.into_iter().collect(),
304 };
305 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
306 let request = HttpRequest::builder()
307 .method(Method::POST)
308 .uri(uri)
309 .header("Content-Type", "application/json")
310 .header("Authorization", format!("Bearer {}", api_key))
311 .body(body)
312 .map(|request| client.send(request));
313
314 async move {
315 let mut response = request?.await?;
316 let mut body = String::new();
317 response.body_mut().read_to_string(&mut body).await?;
318
319 if response.status().is_success() {
320 let response: OpenAiEmbeddingResponse =
321 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
322 Ok(response)
323 } else {
324 Err(anyhow!(
325 "error during embedding, status: {:?}, body: {:?}",
326 response.status(),
327 body
328 ))
329 }
330 }
331}