1pub mod batches;
2pub mod responses;
3
4use anyhow::{Context as _, Result, anyhow};
5use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
6use http_client::{
7 AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
8 http::{HeaderMap, HeaderValue},
9};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12pub use settings::OpenAiReasoningEffort as ReasoningEffort;
13use std::{convert::TryFrom, future::Future};
14use strum::EnumIter;
15use thiserror::Error;
16
17pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
18
/// Returns `true` when `opt` is `None` or wraps an empty slice-like value.
/// Used as a serde `skip_serializing_if` predicate below.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
22
/// Author role of a chat message, serialized as its lowercase wire name
/// (`user`, `assistant`, `system`, `tool`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
31
32impl TryFrom<String> for Role {
33 type Error = anyhow::Error;
34
35 fn try_from(value: String) -> Result<Self> {
36 match value.as_str() {
37 "user" => Ok(Self::User),
38 "assistant" => Ok(Self::Assistant),
39 "system" => Ok(Self::System),
40 "tool" => Ok(Self::Tool),
41 _ => anyhow::bail!("invalid role '{value}'"),
42 }
43 }
44}
45
46impl From<Role> for String {
47 fn from(val: Role) -> Self {
48 match val {
49 Role::User => "user".to_owned(),
50 Role::Assistant => "assistant".to_owned(),
51 Role::System => "system".to_owned(),
52 Role::Tool => "tool".to_owned(),
53 }
54 }
55}
56
/// Known OpenAI models, plus a `Custom` escape hatch for arbitrary
/// OpenAI-compatible model ids. Serialized by the model's API id.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    #[default]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A user-configured model not in the list above.
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Total context window size in tokens.
        max_tokens: u64,
        /// Optional cap on output tokens (`max_output_tokens` request field).
        max_output_tokens: Option<u64>,
        /// Optional cap for the `max_completion_tokens` request field.
        max_completion_tokens: Option<u64>,
        /// Optional reasoning effort to request for this model.
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether this model supports the Chat Completions API
        /// (defaults to `true` when absent from configuration).
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
104
/// Serde default for `Model::Custom::supports_chat_completions`: custom
/// models are assumed to support the Chat Completions API unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
108
109impl Model {
110 pub fn default_fast() -> Self {
111 Self::FiveMini
112 }
113
114 pub fn from_id(id: &str) -> Result<Self> {
115 match id {
116 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
117 "gpt-4" => Ok(Self::Four),
118 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
119 "gpt-4o-mini" => Ok(Self::FourOmniMini),
120 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
121 "o1" => Ok(Self::O1),
122 "o3-mini" => Ok(Self::O3Mini),
123 "o3" => Ok(Self::O3),
124 "gpt-5" => Ok(Self::Five),
125 "gpt-5-codex" => Ok(Self::FiveCodex),
126 "gpt-5-mini" => Ok(Self::FiveMini),
127 "gpt-5-nano" => Ok(Self::FiveNano),
128 "gpt-5.1" => Ok(Self::FivePointOne),
129 "gpt-5.2" => Ok(Self::FivePointTwo),
130 "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
131 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
132 }
133 }
134
135 pub fn id(&self) -> &str {
136 match self {
137 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
138 Self::Four => "gpt-4",
139 Self::FourTurbo => "gpt-4-turbo",
140 Self::FourOmniMini => "gpt-4o-mini",
141 Self::FourPointOneNano => "gpt-4.1-nano",
142 Self::O1 => "o1",
143 Self::O3Mini => "o3-mini",
144 Self::O3 => "o3",
145 Self::Five => "gpt-5",
146 Self::FiveCodex => "gpt-5-codex",
147 Self::FiveMini => "gpt-5-mini",
148 Self::FiveNano => "gpt-5-nano",
149 Self::FivePointOne => "gpt-5.1",
150 Self::FivePointTwo => "gpt-5.2",
151 Self::FivePointTwoCodex => "gpt-5.2-codex",
152 Self::Custom { name, .. } => name,
153 }
154 }
155
156 pub fn display_name(&self) -> &str {
157 match self {
158 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
159 Self::Four => "gpt-4",
160 Self::FourTurbo => "gpt-4-turbo",
161 Self::FourOmniMini => "gpt-4o-mini",
162 Self::FourPointOneNano => "gpt-4.1-nano",
163 Self::O1 => "o1",
164 Self::O3Mini => "o3-mini",
165 Self::O3 => "o3",
166 Self::Five => "gpt-5",
167 Self::FiveCodex => "gpt-5-codex",
168 Self::FiveMini => "gpt-5-mini",
169 Self::FiveNano => "gpt-5-nano",
170 Self::FivePointOne => "gpt-5.1",
171 Self::FivePointTwo => "gpt-5.2",
172 Self::FivePointTwoCodex => "gpt-5.2-codex",
173 Self::Custom { display_name, .. } => display_name.as_deref().unwrap_or(&self.id()),
174 }
175 }
176
177 pub fn max_token_count(&self) -> u64 {
178 match self {
179 Self::ThreePointFiveTurbo => 16_385,
180 Self::Four => 8_192,
181 Self::FourTurbo => 128_000,
182 Self::FourOmniMini => 128_000,
183 Self::FourPointOneNano => 1_047_576,
184 Self::O1 => 200_000,
185 Self::O3Mini => 200_000,
186 Self::O3 => 200_000,
187 Self::Five => 272_000,
188 Self::FiveCodex => 272_000,
189 Self::FiveMini => 272_000,
190 Self::FiveNano => 272_000,
191 Self::FivePointOne => 400_000,
192 Self::FivePointTwo => 400_000,
193 Self::FivePointTwoCodex => 400_000,
194 Self::Custom { max_tokens, .. } => *max_tokens,
195 }
196 }
197
198 pub fn max_output_tokens(&self) -> Option<u64> {
199 match self {
200 Self::Custom {
201 max_output_tokens, ..
202 } => *max_output_tokens,
203 Self::ThreePointFiveTurbo => Some(4_096),
204 Self::Four => Some(8_192),
205 Self::FourTurbo => Some(4_096),
206 Self::FourOmniMini => Some(16_384),
207 Self::FourPointOneNano => Some(32_768),
208 Self::O1 => Some(100_000),
209 Self::O3Mini => Some(100_000),
210 Self::O3 => Some(100_000),
211 Self::Five => Some(128_000),
212 Self::FiveCodex => Some(128_000),
213 Self::FiveMini => Some(128_000),
214 Self::FiveNano => Some(128_000),
215 Self::FivePointOne => Some(128_000),
216 Self::FivePointTwo => Some(128_000),
217 Self::FivePointTwoCodex => Some(128_000),
218 }
219 }
220
221 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
222 match self {
223 Self::Custom {
224 reasoning_effort, ..
225 } => reasoning_effort.to_owned(),
226 _ => None,
227 }
228 }
229
230 pub fn supports_chat_completions(&self) -> bool {
231 match self {
232 Self::Custom {
233 supports_chat_completions,
234 ..
235 } => *supports_chat_completions,
236 Self::FiveCodex | Self::FivePointTwoCodex => false,
237 _ => true,
238 }
239 }
240
241 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
242 ///
243 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
244 pub fn supports_parallel_tool_calls(&self) -> bool {
245 match self {
246 Self::ThreePointFiveTurbo
247 | Self::Four
248 | Self::FourTurbo
249 | Self::FourOmniMini
250 | Self::FourPointOneNano
251 | Self::Five
252 | Self::FiveCodex
253 | Self::FiveMini
254 | Self::FivePointOne
255 | Self::FivePointTwo
256 | Self::FivePointTwoCodex
257 | Self::FiveNano => true,
258 Self::O1 | Self::O3 | Self::O3Mini | Model::Custom { .. } => false,
259 }
260 }
261
262 /// Returns whether the given model supports the `prompt_cache_key` parameter.
263 ///
264 /// If the model does not support the parameter, do not pass it up.
265 pub fn supports_prompt_cache_key(&self) -> bool {
266 true
267 }
268}
269
/// Request body for the Chat Completions endpoint (`{api_url}/chat/completions`).
/// Optional/empty fields are omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to run, e.g. `gpt-5-mini`.
    pub model: String,
    /// Conversation so far, oldest message first.
    pub messages: Vec<RequestMessage>,
    /// When `true`, the server streams the response as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Cache key to improve prompt caching; only sent to models that support it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
293
/// The `tool_choice` request field: one of the lowercase keywords
/// (`auto`, `required`, `none`) or a specific tool to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Force a particular tool; serialized as the tool definition object itself.
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool the model may call. Currently only function tools,
/// serialized as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Declaration of a callable function tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's arguments, if any.
    pub parameters: Option<Value>,
}
317
/// A single chat message, tagged by its `role` field on the wire.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the message only carries tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        /// The tool's output, paired with the call it answers.
        content: MessageContent,
        tool_call_id: String,
    },
}

/// Message content: either a bare string or a list of typed parts
/// (text and images). Untagged, so a plain JSON string deserializes
/// as `Plain` and an array as `Multipart`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
344
345impl MessageContent {
346 pub fn empty() -> Self {
347 MessageContent::Multipart(vec![])
348 }
349
350 pub fn push_part(&mut self, part: MessagePart) {
351 match self {
352 MessageContent::Plain(text) => {
353 *self =
354 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
355 }
356 MessageContent::Multipart(parts) if parts.is_empty() => match part {
357 MessagePart::Text { text } => *self = MessageContent::Plain(text),
358 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
359 },
360 MessageContent::Multipart(parts) => parts.push(part),
361 }
362 }
363}
364
365impl From<Vec<MessagePart>> for MessageContent {
366 fn from(mut parts: Vec<MessagePart>) -> Self {
367 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
368 MessageContent::Plain(std::mem::take(text))
369 } else {
370 MessageContent::Multipart(parts)
371 }
372 }
373}
374
/// One element of multipart message content, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// An image reference inside multipart content.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Image location — presumably either an `https://` URL or a data URI;
    /// TODO(review): confirm which forms callers actually pass.
    pub url: String,
    /// Optional detail level hint for the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
390
/// A completed tool call made by the assistant. The `content` fields are
/// flattened into the same JSON object as `id`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier echoed back in the matching `Tool` message's `tool_call_id`.
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call, tagged by `type`; currently only functions.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Function name plus its arguments as a raw JSON string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// JSON-encoded arguments, kept as a string exactly as the API sends them.
    pub arguments: String,
}
409
/// Non-streaming response body from the Chat Completions endpoint.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One generated completion within a non-streaming response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    /// Why generation stopped (e.g. `"stop"`, `"tool_calls"`), when provided.
    pub finish_reason: Option<String>,
}
426
/// Incremental message delta within a streamed chunk. All fields are
/// optional because each chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Reasoning text delta — NOTE(review): not part of the standard OpenAI
    /// chat schema; presumably emitted by OpenAI-compatible providers. Confirm.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub reasoning_content: Option<String>,
}

/// Incremental fragment of a tool call in a streamed response.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call this fragment belongs to.
    pub index: usize,
    /// Set on the first fragment of a call; absent on continuations.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// Incremental fragment of a function call: name and/or a piece of the
/// JSON arguments string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
453
/// Token accounting reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed chunk; `delta` holds the incremental content.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
467
/// Errors returned by the completion functions in this module.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server answered with a non-success HTTP status; carries the raw
    /// body and headers so callers can inspect rate-limit info and the like.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, transport, body read, …).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
480
/// Error object that can appear inside a streamed event payload.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}

/// A streamed SSE payload: either a normal event or an in-band error.
/// Untagged, so deserialization tries the event shape first.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}

/// One streamed chunk of a chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    /// Present on the final chunk when usage reporting is enabled.
    pub usage: Option<Usage>,
}
498
499pub async fn non_streaming_completion(
500 client: &dyn HttpClient,
501 api_url: &str,
502 api_key: &str,
503 request: Request,
504) -> Result<Response, RequestError> {
505 let uri = format!("{api_url}/chat/completions");
506 let request_builder = HttpRequest::builder()
507 .method(Method::POST)
508 .uri(uri)
509 .header("Content-Type", "application/json")
510 .header("Authorization", format!("Bearer {}", api_key.trim()));
511
512 let request = request_builder
513 .body(AsyncBody::from(
514 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
515 ))
516 .map_err(|e| RequestError::Other(e.into()))?;
517
518 let mut response = client.send(request).await?;
519 if response.status().is_success() {
520 let mut body = String::new();
521 response
522 .body_mut()
523 .read_to_string(&mut body)
524 .await
525 .map_err(|e| RequestError::Other(e.into()))?;
526
527 serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
528 } else {
529 let mut body = String::new();
530 response
531 .body_mut()
532 .read_to_string(&mut body)
533 .await
534 .map_err(|e| RequestError::Other(e.into()))?;
535
536 Err(RequestError::HttpResponseError {
537 provider: "openai".to_owned(),
538 status_code: response.status(),
539 body,
540 headers: response.headers().clone(),
541 })
542 }
543}
544
/// Sends a streaming chat completion request and returns a stream of parsed
/// server-sent events.
///
/// `provider_name` is only used to label `HttpResponseError`s, so this works
/// for any OpenAI-compatible endpoint.
///
/// # Errors
/// Returns `RequestError::HttpResponseError` when the initial response has a
/// non-success status; individual stream items are `Err` when a line fails to
/// parse or the server sends an in-band error payload.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the body as SSE: one `data: <json>` payload per line.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Accept both "data: " and "data:" prefixes; lines
                        // without either (comments, blanks) are skipped via `?`.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        // "[DONE]" is the SSE end-of-stream sentinel.
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                // The server can report errors in-band as JSON.
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-success status: surface the full body and headers to the caller.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
614
/// Supported OpenAI embedding models, serialized by API id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

/// Response body from the embeddings endpoint, one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
638
639pub fn embed<'a>(
640 client: &dyn HttpClient,
641 api_url: &str,
642 api_key: &str,
643 model: OpenAiEmbeddingModel,
644 texts: impl IntoIterator<Item = &'a str>,
645) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
646 let uri = format!("{api_url}/embeddings");
647
648 let request = OpenAiEmbeddingRequest {
649 model,
650 input: texts.into_iter().collect(),
651 };
652 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
653 let request = HttpRequest::builder()
654 .method(Method::POST)
655 .uri(uri)
656 .header("Content-Type", "application/json")
657 .header("Authorization", format!("Bearer {}", api_key.trim()))
658 .body(body)
659 .map(|request| client.send(request));
660
661 async move {
662 let mut response = request?.await?;
663 let mut body = String::new();
664 response.body_mut().read_to_string(&mut body).await?;
665
666 anyhow::ensure!(
667 response.status().is_success(),
668 "error during embedding, status: {:?}, body: {:?}",
669 response.status(),
670 body
671 );
672 let response: OpenAiEmbeddingResponse =
673 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
674 Ok(response)
675 }
676}