1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, future::Future};
7use strum::EnumIter;
8
/// Default base URL for the OpenAI REST API (v1).
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
10
/// Serialization helper: true when `opt` is `None` or holds an empty slice,
/// so optional list fields can be skipped entirely via `skip_serializing_if`.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
14
/// The author of a chat message, matching the OpenAI `role` field
/// (serialized in lowercase: "user", "assistant", "system", "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
23
24impl TryFrom<String> for Role {
25 type Error = anyhow::Error;
26
27 fn try_from(value: String) -> Result<Self> {
28 match value.as_str() {
29 "user" => Ok(Self::User),
30 "assistant" => Ok(Self::Assistant),
31 "system" => Ok(Self::System),
32 "tool" => Ok(Self::Tool),
33 _ => anyhow::bail!("invalid role '{value}'"),
34 }
35 }
36}
37
38impl From<Role> for String {
39 fn from(val: Role) -> Self {
40 match val {
41 Role::User => "user".to_owned(),
42 Role::Assistant => "assistant".to_owned(),
43 Role::System => "system".to_owned(),
44 Role::Tool => "tool".to_owned(),
45 }
46 }
47}
48
/// An OpenAI chat model: either a well-known hosted model or a
/// user-configured custom model with explicit limits.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,

    /// A model not in the list above, configured by the user.
    #[serde(rename = "custom")]
    Custom {
        // The model id sent to the API (e.g. a fine-tuned model name).
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Total context window size in tokens.
        max_tokens: u64,
        // Optional cap on output tokens; `None` means unspecified.
        max_output_tokens: Option<u64>,
        // Optional `max_completion_tokens` request parameter value.
        max_completion_tokens: Option<u64>,
        // Optional reasoning-effort setting for reasoning-capable models.
        reasoning_effort: Option<ReasoningEffort>,
    },
}
95
96impl Model {
97 pub fn default_fast() -> Self {
98 // TODO: Replace with FiveMini since all other models are deprecated
99 Self::FourPointOneMini
100 }
101
102 pub fn from_id(id: &str) -> Result<Self> {
103 match id {
104 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
105 "gpt-4" => Ok(Self::Four),
106 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
107 "gpt-4o" => Ok(Self::FourOmni),
108 "gpt-4o-mini" => Ok(Self::FourOmniMini),
109 "gpt-4.1" => Ok(Self::FourPointOne),
110 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
111 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
112 "o1" => Ok(Self::O1),
113 "o3-mini" => Ok(Self::O3Mini),
114 "o3" => Ok(Self::O3),
115 "o4-mini" => Ok(Self::O4Mini),
116 "gpt-5" => Ok(Self::Five),
117 "gpt-5-mini" => Ok(Self::FiveMini),
118 "gpt-5-nano" => Ok(Self::FiveNano),
119 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 match self {
125 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
126 Self::Four => "gpt-4",
127 Self::FourTurbo => "gpt-4-turbo",
128 Self::FourOmni => "gpt-4o",
129 Self::FourOmniMini => "gpt-4o-mini",
130 Self::FourPointOne => "gpt-4.1",
131 Self::FourPointOneMini => "gpt-4.1-mini",
132 Self::FourPointOneNano => "gpt-4.1-nano",
133 Self::O1 => "o1",
134 Self::O3Mini => "o3-mini",
135 Self::O3 => "o3",
136 Self::O4Mini => "o4-mini",
137 Self::Five => "gpt-5",
138 Self::FiveMini => "gpt-5-mini",
139 Self::FiveNano => "gpt-5-nano",
140 Self::Custom { name, .. } => name,
141 }
142 }
143
144 pub fn display_name(&self) -> &str {
145 match self {
146 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147 Self::Four => "gpt-4",
148 Self::FourTurbo => "gpt-4-turbo",
149 Self::FourOmni => "gpt-4o",
150 Self::FourOmniMini => "gpt-4o-mini",
151 Self::FourPointOne => "gpt-4.1",
152 Self::FourPointOneMini => "gpt-4.1-mini",
153 Self::FourPointOneNano => "gpt-4.1-nano",
154 Self::O1 => "o1",
155 Self::O3Mini => "o3-mini",
156 Self::O3 => "o3",
157 Self::O4Mini => "o4-mini",
158 Self::Five => "gpt-5",
159 Self::FiveMini => "gpt-5-mini",
160 Self::FiveNano => "gpt-5-nano",
161 Self::Custom {
162 name, display_name, ..
163 } => display_name.as_ref().unwrap_or(name),
164 }
165 }
166
167 pub fn max_token_count(&self) -> u64 {
168 match self {
169 Self::ThreePointFiveTurbo => 16_385,
170 Self::Four => 8_192,
171 Self::FourTurbo => 128_000,
172 Self::FourOmni => 128_000,
173 Self::FourOmniMini => 128_000,
174 Self::FourPointOne => 1_047_576,
175 Self::FourPointOneMini => 1_047_576,
176 Self::FourPointOneNano => 1_047_576,
177 Self::O1 => 200_000,
178 Self::O3Mini => 200_000,
179 Self::O3 => 200_000,
180 Self::O4Mini => 200_000,
181 Self::Five => 272_000,
182 Self::FiveMini => 272_000,
183 Self::FiveNano => 272_000,
184 Self::Custom { max_tokens, .. } => *max_tokens,
185 }
186 }
187
188 pub fn max_output_tokens(&self) -> Option<u64> {
189 match self {
190 Self::Custom {
191 max_output_tokens, ..
192 } => *max_output_tokens,
193 Self::ThreePointFiveTurbo => Some(4_096),
194 Self::Four => Some(8_192),
195 Self::FourTurbo => Some(4_096),
196 Self::FourOmni => Some(16_384),
197 Self::FourOmniMini => Some(16_384),
198 Self::FourPointOne => Some(32_768),
199 Self::FourPointOneMini => Some(32_768),
200 Self::FourPointOneNano => Some(32_768),
201 Self::O1 => Some(100_000),
202 Self::O3Mini => Some(100_000),
203 Self::O3 => Some(100_000),
204 Self::O4Mini => Some(100_000),
205 Self::Five => Some(128_000),
206 Self::FiveMini => Some(128_000),
207 Self::FiveNano => Some(128_000),
208 }
209 }
210
211 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
212 match self {
213 Self::Custom {
214 reasoning_effort, ..
215 } => reasoning_effort.to_owned(),
216 _ => None,
217 }
218 }
219
220 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
221 ///
222 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
223 pub fn supports_parallel_tool_calls(&self) -> bool {
224 match self {
225 Self::ThreePointFiveTurbo
226 | Self::Four
227 | Self::FourTurbo
228 | Self::FourOmni
229 | Self::FourOmniMini
230 | Self::FourPointOne
231 | Self::FourPointOneMini
232 | Self::FourPointOneNano
233 | Self::Five
234 | Self::FiveMini
235 | Self::FiveNano => true,
236 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
237 }
238 }
239
240 /// Returns whether the given model supports the `prompt_cache_key` parameter.
241 ///
242 /// If the model does not support the parameter, do not pass it up.
243 pub fn supports_prompt_cache_key(&self) -> bool {
244 true
245 }
246}
247
/// Request body for the `POST /chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model id, e.g. "gpt-4o" (see `Model::id`).
    pub model: String,
    // Conversation so far, oldest first.
    pub messages: Vec<RequestMessage>,
    // When true, the server responds with an SSE stream of deltas.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    // Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    // Tool definitions the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    // Cache-routing hint; only send for models that support it
    // (see `Model::supports_prompt_cache_key`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
270
271#[derive(Debug, Serialize, Deserialize)]
272#[serde(untagged)]
273pub enum ToolChoice {
274 Auto,
275 Required,
276 None,
277 Other(ToolDefinition),
278}
279
/// Value for the `reasoning_effort` request parameter
/// (serialized lowercase: "minimal", "low", "medium", "high").
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
}
289
/// A tool the model may call; currently only function tools, serialized as
/// `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
296
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    // Human-readable description shown to the model.
    pub description: Option<String>,
    // JSON Schema for the function's arguments.
    pub parameters: Option<Value>,
}
303
/// A single chat message, tagged on the wire by its `role` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // May be absent when the assistant message is only tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        // The result of a tool call, linked back via `tool_call_id`.
        content: MessageContent,
        tool_call_id: String,
    },
}
323
/// Message content: either a plain string or a list of typed parts
/// (text and images). Untagged, so a bare JSON string deserializes to
/// `Plain` and an array to `Multipart`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
330
331impl MessageContent {
332 pub fn empty() -> Self {
333 MessageContent::Multipart(vec![])
334 }
335
336 pub fn push_part(&mut self, part: MessagePart) {
337 match self {
338 MessageContent::Plain(text) => {
339 *self =
340 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
341 }
342 MessageContent::Multipart(parts) if parts.is_empty() => match part {
343 MessagePart::Text { text } => *self = MessageContent::Plain(text),
344 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
345 },
346 MessageContent::Multipart(parts) => parts.push(part),
347 }
348 }
349}
350
351impl From<Vec<MessagePart>> for MessageContent {
352 fn from(mut parts: Vec<MessagePart>) -> Self {
353 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
354 MessageContent::Plain(std::mem::take(text))
355 } else {
356 MessageContent::Multipart(parts)
357 }
358 }
359}
360
/// One element of a multipart message, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
369
/// An image reference in a message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    // URL or data-URI of the image.
    pub url: String,
    // Optional detail level hint; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
376
/// A completed tool call emitted by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` in the JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
383
/// Payload of a tool call, tagged by `type`; only function calls exist today.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
389
/// Name and raw JSON arguments of a function call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments as a JSON-encoded string, exactly as the API sends them.
    pub arguments: String,
}
395
/// Incremental message delta from a streamed completion; all fields are
/// optional because each SSE chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
403
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    // Only present on the first chunk of a given call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
414
/// Partial function-call data in a streamed chunk; `arguments` fragments
/// concatenate across chunks into the full JSON string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
420
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
427
/// One choice in a streamed response chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // e.g. "stop" or "tool_calls" on the final chunk; `None` mid-stream.
    pub finish_reason: Option<String>,
}
434
/// Error payload as returned by the OpenAI API (`{"error": {"message": …}}`).
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAiError {
    message: String,
}
439
/// A parsed SSE line: either a normal event or an in-stream error object.
/// Untagged, so whichever shape the JSON matches wins.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: OpenAiError },
}
446
/// One streamed completion chunk: per-choice deltas plus, on the final
/// chunk (when requested), token usage.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
452
/// Sends a chat-completion request and returns a stream of parsed SSE events.
///
/// On a non-success HTTP status, the body is read and surfaced as an error
/// (using the API's own error message when one can be parsed out).
///
/// # Errors
/// Fails if the request cannot be built or sent, or if the server responds
/// with a non-success status. Per-line parse failures inside the stream are
/// yielded as `Err` items rather than terminating the stream.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the body as server-sent events: each line is "data: <json>",
        // terminated by a literal "data: [DONE]".
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Lines without the SSE data prefix are skipped.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        // Prefer the API's own error message; fall back to status + raw body.
        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "API request to {} failed: {}",
                api_url,
                response.error.message,
            )),

            _ => anyhow::bail!(
                "API request to {} failed with status {}: {}",
                api_url,
                response.status(),
                body,
            ),
        }
    }
}
525
/// Supported OpenAI embedding models.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
533
/// Request body for the `POST /embeddings` endpoint; borrows the input
/// strings to avoid copying them into the request.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
539
/// Response from the embeddings endpoint: one entry per input, in order.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
544
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
549
/// Requests embeddings for `texts` and returns a `'static` future with the
/// parsed response.
///
/// The request is built (and sending is initiated) synchronously, before the
/// returned future is awaited; this is what lets the future be `'static`
/// while `texts` is only borrowed for the call.
///
/// # Errors
/// Fails if the request cannot be built or sent, if the server returns a
/// non-success status, or if the response body cannot be parsed.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serialization of our own request type is infallible in practice;
    // a failure here would be a bug in the type definitions.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}