1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, future::Future};
7use strum::EnumIter;
8
/// Default base URL for the OpenAI REST API.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
10
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate so empty optional collections
/// are omitted from serialized JSON.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
14
/// Author role of a chat message, serialized in lowercase
/// ("user", "assistant", "system", "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
23
24impl TryFrom<String> for Role {
25 type Error = anyhow::Error;
26
27 fn try_from(value: String) -> Result<Self> {
28 match value.as_str() {
29 "user" => Ok(Self::User),
30 "assistant" => Ok(Self::Assistant),
31 "system" => Ok(Self::System),
32 "tool" => Ok(Self::Tool),
33 _ => anyhow::bail!("invalid role '{value}'"),
34 }
35 }
36}
37
38impl From<Role> for String {
39 fn from(val: Role) -> Self {
40 match val {
41 Role::User => "user".to_owned(),
42 Role::Assistant => "assistant".to_owned(),
43 Role::System => "system".to_owned(),
44 Role::Tool => "tool".to_owned(),
45 }
46 }
47}
48
/// Known OpenAI chat models, plus a `Custom` escape hatch for
/// OpenAI-compatible endpoints.
///
/// The serde renames map each variant to its wire id (e.g. "gpt-4o").
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,

    /// A user-configured model for OpenAI-compatible providers.
    #[serde(rename = "custom")]
    Custom {
        /// The id sent over the wire in the request's `model` field.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Total context window size, in tokens.
        max_tokens: u64,
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        reasoning_effort: Option<ReasoningEffort>,
    },
}
95
96impl Model {
97 pub fn default_fast() -> Self {
98 // TODO: Replace with FiveMini since all other models are deprecated
99 Self::FourPointOneMini
100 }
101
102 pub fn from_id(id: &str) -> Result<Self> {
103 match id {
104 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
105 "gpt-4" => Ok(Self::Four),
106 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
107 "gpt-4o" => Ok(Self::FourOmni),
108 "gpt-4o-mini" => Ok(Self::FourOmniMini),
109 "gpt-4.1" => Ok(Self::FourPointOne),
110 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
111 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
112 "o1" => Ok(Self::O1),
113 "o3-mini" => Ok(Self::O3Mini),
114 "o3" => Ok(Self::O3),
115 "o4-mini" => Ok(Self::O4Mini),
116 "gpt-5" => Ok(Self::Five),
117 "gpt-5-mini" => Ok(Self::FiveMini),
118 "gpt-5-nano" => Ok(Self::FiveNano),
119 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 match self {
125 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
126 Self::Four => "gpt-4",
127 Self::FourTurbo => "gpt-4-turbo",
128 Self::FourOmni => "gpt-4o",
129 Self::FourOmniMini => "gpt-4o-mini",
130 Self::FourPointOne => "gpt-4.1",
131 Self::FourPointOneMini => "gpt-4.1-mini",
132 Self::FourPointOneNano => "gpt-4.1-nano",
133 Self::O1 => "o1",
134 Self::O3Mini => "o3-mini",
135 Self::O3 => "o3",
136 Self::O4Mini => "o4-mini",
137 Self::Five => "gpt-5",
138 Self::FiveMini => "gpt-5-mini",
139 Self::FiveNano => "gpt-5-nano",
140 Self::Custom { name, .. } => name,
141 }
142 }
143
144 pub fn display_name(&self) -> &str {
145 match self {
146 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147 Self::Four => "gpt-4",
148 Self::FourTurbo => "gpt-4-turbo",
149 Self::FourOmni => "gpt-4o",
150 Self::FourOmniMini => "gpt-4o-mini",
151 Self::FourPointOne => "gpt-4.1",
152 Self::FourPointOneMini => "gpt-4.1-mini",
153 Self::FourPointOneNano => "gpt-4.1-nano",
154 Self::O1 => "o1",
155 Self::O3Mini => "o3-mini",
156 Self::O3 => "o3",
157 Self::O4Mini => "o4-mini",
158 Self::Five => "gpt-5",
159 Self::FiveMini => "gpt-5-mini",
160 Self::FiveNano => "gpt-5-nano",
161 Self::Custom {
162 name, display_name, ..
163 } => display_name.as_ref().unwrap_or(name),
164 }
165 }
166
167 pub fn max_token_count(&self) -> u64 {
168 match self {
169 Self::ThreePointFiveTurbo => 16_385,
170 Self::Four => 8_192,
171 Self::FourTurbo => 128_000,
172 Self::FourOmni => 128_000,
173 Self::FourOmniMini => 128_000,
174 Self::FourPointOne => 1_047_576,
175 Self::FourPointOneMini => 1_047_576,
176 Self::FourPointOneNano => 1_047_576,
177 Self::O1 => 200_000,
178 Self::O3Mini => 200_000,
179 Self::O3 => 200_000,
180 Self::O4Mini => 200_000,
181 Self::Five => 272_000,
182 Self::FiveMini => 272_000,
183 Self::FiveNano => 272_000,
184 Self::Custom { max_tokens, .. } => *max_tokens,
185 }
186 }
187
188 pub fn max_output_tokens(&self) -> Option<u64> {
189 match self {
190 Self::Custom {
191 max_output_tokens, ..
192 } => *max_output_tokens,
193 Self::ThreePointFiveTurbo => Some(4_096),
194 Self::Four => Some(8_192),
195 Self::FourTurbo => Some(4_096),
196 Self::FourOmni => Some(16_384),
197 Self::FourOmniMini => Some(16_384),
198 Self::FourPointOne => Some(32_768),
199 Self::FourPointOneMini => Some(32_768),
200 Self::FourPointOneNano => Some(32_768),
201 Self::O1 => Some(100_000),
202 Self::O3Mini => Some(100_000),
203 Self::O3 => Some(100_000),
204 Self::O4Mini => Some(100_000),
205 Self::Five => Some(128_000),
206 Self::FiveMini => Some(128_000),
207 Self::FiveNano => Some(128_000),
208 }
209 }
210
211 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
212 match self {
213 Self::Custom {
214 reasoning_effort, ..
215 } => reasoning_effort.to_owned(),
216 _ => None,
217 }
218 }
219
220 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
221 ///
222 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
223 pub fn supports_parallel_tool_calls(&self) -> bool {
224 match self {
225 Self::ThreePointFiveTurbo
226 | Self::Four
227 | Self::FourTurbo
228 | Self::FourOmni
229 | Self::FourOmniMini
230 | Self::FourPointOne
231 | Self::FourPointOneMini
232 | Self::FourPointOneNano
233 | Self::Five
234 | Self::FiveMini
235 | Self::FiveNano => true,
236 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
237 }
238 }
239}
240
241#[derive(Debug, Serialize, Deserialize)]
242pub struct Request {
243 pub model: String,
244 pub messages: Vec<RequestMessage>,
245 pub stream: bool,
246 #[serde(default, skip_serializing_if = "Option::is_none")]
247 pub max_completion_tokens: Option<u64>,
248 #[serde(default, skip_serializing_if = "Vec::is_empty")]
249 pub stop: Vec<String>,
250 pub temperature: f32,
251 #[serde(default, skip_serializing_if = "Option::is_none")]
252 pub tool_choice: Option<ToolChoice>,
253 /// Whether to enable parallel function calling during tool use.
254 #[serde(default, skip_serializing_if = "Option::is_none")]
255 pub parallel_tool_calls: Option<bool>,
256 #[serde(default, skip_serializing_if = "Vec::is_empty")]
257 pub tools: Vec<ToolDefinition>,
258 #[serde(default, skip_serializing_if = "Option::is_none")]
259 pub prompt_cache_key: Option<String>,
260 pub reasoning_effort: Option<ReasoningEffort>,
261}
262
/// The `tool_choice` request parameter.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants `Auto`,
/// `Required`, and `None` serialize as JSON `null` rather than the strings
/// "auto"/"required"/"none" — confirm this is intended, or whether custom
/// (de)serialization is needed for these variants.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}
271
/// The `reasoning_effort` request parameter, serialized in lowercase.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
}
281
/// A tool made available to the model; currently only function tools exist.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
288
/// Description of a callable function tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema for the function's arguments, when provided.
    pub parameters: Option<Value>,
}
295
/// A single chat message, tagged by `role` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the message only carries tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Id of the tool call this message is the result of.
        tool_call_id: String,
    },
}
315
/// Message content in either of the two encodings the API accepts:
/// a bare string, or a list of typed parts. `untagged` picks by JSON shape.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
322
323impl MessageContent {
324 pub fn empty() -> Self {
325 MessageContent::Multipart(vec![])
326 }
327
328 pub fn push_part(&mut self, part: MessagePart) {
329 match self {
330 MessageContent::Plain(text) => {
331 *self =
332 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
333 }
334 MessageContent::Multipart(parts) if parts.is_empty() => match part {
335 MessagePart::Text { text } => *self = MessageContent::Plain(text),
336 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
337 },
338 MessageContent::Multipart(parts) => parts.push(part),
339 }
340 }
341}
342
343impl From<Vec<MessagePart>> for MessageContent {
344 fn from(mut parts: Vec<MessagePart>) -> Self {
345 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
346 MessageContent::Plain(std::mem::take(text))
347 } else {
348 MessageContent::Multipart(parts)
349 }
350 }
351}
352
/// One element of multipart message content, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
361
/// Image reference in an `image_url` message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail hint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
368
/// A tool call emitted by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened so the call's `type`/`function` fields sit beside `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
375
/// Payload of a tool call, tagged by `type`; only function calls exist today.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
381
/// A concrete function invocation requested by the model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a raw JSON-encoded string, exactly as the API sends them.
    pub arguments: String,
}
387
/// Incremental message update within a streamed choice; all fields may be
/// partial and are merged across events by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
395
/// A streamed fragment of a tool call; fragments sharing the same `index`
/// belong to the same call and are accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
406
/// Streamed fragment of a function call; `arguments` arrives in pieces.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
412
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
419
/// One streamed choice: an incremental delta plus an optional finish reason.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Set on the final chunk of the choice (e.g. "stop", "tool_calls").
    pub finish_reason: Option<String>,
}
426
/// One SSE payload: either a regular stream event or an error envelope.
/// With `untagged`, whichever variant the JSON matches first is used.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
433
/// A parsed streaming chunk from `/chat/completions`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Usually only present on the final chunk, when usage reporting is enabled.
    pub usage: Option<Usage>,
}
440
/// Streams a chat completion from an OpenAI-compatible `/chat/completions`
/// endpoint via server-sent events (SSE).
///
/// On an HTTP success status, returns a stream of parsed
/// [`ResponseStreamEvent`]s. On failure, reads the whole body and surfaces
/// the API's error message, falling back to the raw status and body.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines carry a "data: " prefix; lines
                        // without it (e.g. keep-alives) are skipped via `?`.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel sent by the API.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => {
                                    // Log the unparseable payload for debugging,
                                    // then propagate the parse error downstream.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Minimal mirror of OpenAI's error envelope:
        // `{"error": {"message": "..."}}`.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "API request to {} failed: {}",
                api_url,
                response.error.message,
            )),

            // Body was not the expected error envelope (or the message was
            // empty); report the raw status and body instead.
            _ => anyhow::bail!(
                "API request to {} failed with status {}: {}",
                api_url,
                response.status(),
                body,
            ),
        }
    }
}
518
/// Supported embedding models, serialized to their wire identifiers.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
526
/// Request body for the `/embeddings` endpoint (serialize-only).
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
532
/// Response body from the `/embeddings` endpoint.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
537
/// A single embedding vector from the response's `data` array.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
542
/// Requests embeddings for `texts` from an OpenAI-compatible `/embeddings`
/// endpoint.
///
/// The request is built and the send is initiated before the `async move`
/// block, so the returned future does not need to borrow `client`, `api_url`,
/// `api_key`, or `texts` — which is what allows the `'static` bound.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serializing this plain struct should not fail in practice.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    // Capture any builder error as a value; it is surfaced via `request?`
    // inside the future rather than here.
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}