1pub mod responses;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
5use http_client::{
6 AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
7 http::{HeaderMap, HeaderValue},
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11pub use settings::OpenAiReasoningEffort as ReasoningEffort;
12use std::{convert::TryFrom, future::Future};
13use strum::EnumIter;
14use thiserror::Error;
15
16pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
17
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate for optional list fields.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
21
/// Author of a chat message, serialized as the OpenAI `role` field
/// (lowercase: `"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
30
31impl TryFrom<String> for Role {
32 type Error = anyhow::Error;
33
34 fn try_from(value: String) -> Result<Self> {
35 match value.as_str() {
36 "user" => Ok(Self::User),
37 "assistant" => Ok(Self::Assistant),
38 "system" => Ok(Self::System),
39 "tool" => Ok(Self::Tool),
40 _ => anyhow::bail!("invalid role '{value}'"),
41 }
42 }
43}
44
45impl From<Role> for String {
46 fn from(val: Role) -> Self {
47 match val {
48 Role::User => "user".to_owned(),
49 Role::Assistant => "assistant".to_owned(),
50 Role::System => "system".to_owned(),
51 Role::Tool => "tool".to_owned(),
52 }
53 }
54}
55
/// Known OpenAI models, serialized by their API model id
/// (e.g. `"gpt-4o"`), plus a `Custom` variant for user-configured models.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        // The raw model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Total context window size in tokens.
        max_tokens: u64,
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        reasoning_effort: Option<ReasoningEffort>,
        // Defaults to `true` when absent from the configuration.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
111
/// Serde default for `Model::Custom::supports_chat_completions`:
/// custom models are assumed to support the Chat Completions API.
const fn default_supports_chat_completions() -> bool {
    true
}
115
116impl Model {
117 pub fn default_fast() -> Self {
118 // TODO: Replace with FiveMini since all other models are deprecated
119 Self::FourPointOneMini
120 }
121
122 pub fn from_id(id: &str) -> Result<Self> {
123 match id {
124 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
125 "gpt-4" => Ok(Self::Four),
126 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
127 "gpt-4o" => Ok(Self::FourOmni),
128 "gpt-4o-mini" => Ok(Self::FourOmniMini),
129 "gpt-4.1" => Ok(Self::FourPointOne),
130 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
131 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
132 "o1" => Ok(Self::O1),
133 "o3-mini" => Ok(Self::O3Mini),
134 "o3" => Ok(Self::O3),
135 "o4-mini" => Ok(Self::O4Mini),
136 "gpt-5" => Ok(Self::Five),
137 "gpt-5-codex" => Ok(Self::FiveCodex),
138 "gpt-5-mini" => Ok(Self::FiveMini),
139 "gpt-5-nano" => Ok(Self::FiveNano),
140 "gpt-5.1" => Ok(Self::FivePointOne),
141 "gpt-5.2" => Ok(Self::FivePointTwo),
142 "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
143 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
144 }
145 }
146
147 pub fn id(&self) -> &str {
148 match self {
149 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
150 Self::Four => "gpt-4",
151 Self::FourTurbo => "gpt-4-turbo",
152 Self::FourOmni => "gpt-4o",
153 Self::FourOmniMini => "gpt-4o-mini",
154 Self::FourPointOne => "gpt-4.1",
155 Self::FourPointOneMini => "gpt-4.1-mini",
156 Self::FourPointOneNano => "gpt-4.1-nano",
157 Self::O1 => "o1",
158 Self::O3Mini => "o3-mini",
159 Self::O3 => "o3",
160 Self::O4Mini => "o4-mini",
161 Self::Five => "gpt-5",
162 Self::FiveCodex => "gpt-5-codex",
163 Self::FiveMini => "gpt-5-mini",
164 Self::FiveNano => "gpt-5-nano",
165 Self::FivePointOne => "gpt-5.1",
166 Self::FivePointTwo => "gpt-5.2",
167 Self::FivePointTwoCodex => "gpt-5.2-codex",
168 Self::Custom { name, .. } => name,
169 }
170 }
171
172 pub fn display_name(&self) -> &str {
173 match self {
174 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
175 Self::Four => "gpt-4",
176 Self::FourTurbo => "gpt-4-turbo",
177 Self::FourOmni => "gpt-4o",
178 Self::FourOmniMini => "gpt-4o-mini",
179 Self::FourPointOne => "gpt-4.1",
180 Self::FourPointOneMini => "gpt-4.1-mini",
181 Self::FourPointOneNano => "gpt-4.1-nano",
182 Self::O1 => "o1",
183 Self::O3Mini => "o3-mini",
184 Self::O3 => "o3",
185 Self::O4Mini => "o4-mini",
186 Self::Five => "gpt-5",
187 Self::FiveCodex => "gpt-5-codex",
188 Self::FiveMini => "gpt-5-mini",
189 Self::FiveNano => "gpt-5-nano",
190 Self::FivePointOne => "gpt-5.1",
191 Self::FivePointTwo => "gpt-5.2",
192 Self::FivePointTwoCodex => "gpt-5.2-codex",
193 Self::Custom {
194 name, display_name, ..
195 } => display_name.as_ref().unwrap_or(name),
196 }
197 }
198
199 pub fn max_token_count(&self) -> u64 {
200 match self {
201 Self::ThreePointFiveTurbo => 16_385,
202 Self::Four => 8_192,
203 Self::FourTurbo => 128_000,
204 Self::FourOmni => 128_000,
205 Self::FourOmniMini => 128_000,
206 Self::FourPointOne => 1_047_576,
207 Self::FourPointOneMini => 1_047_576,
208 Self::FourPointOneNano => 1_047_576,
209 Self::O1 => 200_000,
210 Self::O3Mini => 200_000,
211 Self::O3 => 200_000,
212 Self::O4Mini => 200_000,
213 Self::Five => 272_000,
214 Self::FiveCodex => 272_000,
215 Self::FiveMini => 272_000,
216 Self::FiveNano => 272_000,
217 Self::FivePointOne => 400_000,
218 Self::FivePointTwo => 400_000,
219 Self::FivePointTwoCodex => 400_000,
220 Self::Custom { max_tokens, .. } => *max_tokens,
221 }
222 }
223
224 pub fn max_output_tokens(&self) -> Option<u64> {
225 match self {
226 Self::Custom {
227 max_output_tokens, ..
228 } => *max_output_tokens,
229 Self::ThreePointFiveTurbo => Some(4_096),
230 Self::Four => Some(8_192),
231 Self::FourTurbo => Some(4_096),
232 Self::FourOmni => Some(16_384),
233 Self::FourOmniMini => Some(16_384),
234 Self::FourPointOne => Some(32_768),
235 Self::FourPointOneMini => Some(32_768),
236 Self::FourPointOneNano => Some(32_768),
237 Self::O1 => Some(100_000),
238 Self::O3Mini => Some(100_000),
239 Self::O3 => Some(100_000),
240 Self::O4Mini => Some(100_000),
241 Self::Five => Some(128_000),
242 Self::FiveCodex => Some(128_000),
243 Self::FiveMini => Some(128_000),
244 Self::FiveNano => Some(128_000),
245 Self::FivePointOne => Some(128_000),
246 Self::FivePointTwo => Some(128_000),
247 Self::FivePointTwoCodex => Some(128_000),
248 }
249 }
250
251 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
252 match self {
253 Self::Custom {
254 reasoning_effort, ..
255 } => reasoning_effort.to_owned(),
256 _ => None,
257 }
258 }
259
260 pub fn supports_chat_completions(&self) -> bool {
261 match self {
262 Self::Custom {
263 supports_chat_completions,
264 ..
265 } => *supports_chat_completions,
266 Self::FiveCodex | Self::FivePointTwoCodex => false,
267 _ => true,
268 }
269 }
270
271 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
272 ///
273 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
274 pub fn supports_parallel_tool_calls(&self) -> bool {
275 match self {
276 Self::ThreePointFiveTurbo
277 | Self::Four
278 | Self::FourTurbo
279 | Self::FourOmni
280 | Self::FourOmniMini
281 | Self::FourPointOne
282 | Self::FourPointOneMini
283 | Self::FourPointOneNano
284 | Self::Five
285 | Self::FiveCodex
286 | Self::FiveMini
287 | Self::FivePointOne
288 | Self::FivePointTwo
289 | Self::FivePointTwoCodex
290 | Self::FiveNano => true,
291 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
292 }
293 }
294
295 /// Returns whether the given model supports the `prompt_cache_key` parameter.
296 ///
297 /// If the model does not support the parameter, do not pass it up.
298 pub fn supports_prompt_cache_key(&self) -> bool {
299 true
300 }
301}
302
/// Request body for the Chat Completions endpoint.
/// Optional fields are omitted from the serialized JSON when unset, since
/// some models reject parameters they do not support.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether to stream the response as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
326
/// The `tool_choice` request parameter: either one of the lowercase
/// keywords (`"auto"`, `"required"`, `"none"`) or, via the untagged
/// variant, a specific tool to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
336
/// A tool made available to the model, tagged by `type`
/// (currently only `"function"`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
343
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
350
/// A single chat message, tagged by the lowercase `role` field
/// (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// May be absent when the assistant message only carries tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Id of the tool call this message is responding to.
        tool_call_id: String,
    },
}
370
/// Message content: either a plain string or a list of typed parts
/// (text/image). Untagged, matching the two JSON shapes the API accepts.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
377
378impl MessageContent {
379 pub fn empty() -> Self {
380 MessageContent::Multipart(vec![])
381 }
382
383 pub fn push_part(&mut self, part: MessagePart) {
384 match self {
385 MessageContent::Plain(text) => {
386 *self =
387 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
388 }
389 MessageContent::Multipart(parts) if parts.is_empty() => match part {
390 MessagePart::Text { text } => *self = MessageContent::Plain(text),
391 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
392 },
393 MessageContent::Multipart(parts) => parts.push(part),
394 }
395 }
396}
397
398impl From<Vec<MessagePart>> for MessageContent {
399 fn from(mut parts: Vec<MessagePart>) -> Self {
400 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
401 MessageContent::Plain(std::mem::take(text))
402 } else {
403 MessageContent::Multipart(parts)
404 }
405 }
406}
407
/// One element of multipart message content, tagged by `type`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
416
/// Image reference for an `image_url` message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail level hint; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
423
/// A tool call issued by the assistant; the call payload is flattened
/// alongside the `id` in the serialized JSON.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
430
/// Payload of a tool call, tagged by lowercase `type`
/// (currently only `"function"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
436
/// A concrete function invocation: the function name and its arguments
/// as a raw JSON string (as delivered by the API, not parsed here).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
442
/// Non-streaming chat completion response body.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Creation time as a Unix timestamp (seconds).
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
452
/// One completion choice in a non-streaming [`Response`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
459
/// Incremental message delta in a streaming response chunk; each field
/// may be absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted from serialization when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
467
/// Partial tool call streamed across chunks; `index` correlates chunks
/// belonging to the same call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
478
/// Partial function-call payload within a [`ToolCallChunk`]; `arguments`
/// fragments are concatenated by the consumer across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
484
/// Token usage accounting reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
491
/// One choice within a streaming response chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
498
/// Errors returned by [`stream_completion`].
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server answered with a non-success HTTP status; carries the
    /// response body and headers for diagnostics.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (request building, transport, serialization).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
511
/// Error object embedded in a streamed response event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
516
/// A streamed SSE payload: either a normal event or an `{"error": ...}`
/// object. Untagged so deserialization tries each shape in order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
523
/// One streaming chat completion chunk; `usage` is typically only
/// populated on the final chunk when present.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
529
530pub async fn stream_completion(
531 client: &dyn HttpClient,
532 provider_name: &str,
533 api_url: &str,
534 api_key: &str,
535 request: Request,
536) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
537 let uri = format!("{api_url}/chat/completions");
538 let request_builder = HttpRequest::builder()
539 .method(Method::POST)
540 .uri(uri)
541 .header("Content-Type", "application/json")
542 .header("Authorization", format!("Bearer {}", api_key.trim()));
543
544 let request = request_builder
545 .body(AsyncBody::from(
546 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
547 ))
548 .map_err(|e| RequestError::Other(e.into()))?;
549
550 let mut response = client.send(request).await?;
551 if response.status().is_success() {
552 let reader = BufReader::new(response.into_body());
553 Ok(reader
554 .lines()
555 .filter_map(|line| async move {
556 match line {
557 Ok(line) => {
558 let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
559 if line == "[DONE]" {
560 None
561 } else {
562 match serde_json::from_str(line) {
563 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
564 Ok(ResponseStreamResult::Err { error }) => {
565 Some(Err(anyhow!(error.message)))
566 }
567 Err(error) => {
568 log::error!(
569 "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
570 Response: `{}`",
571 error,
572 line,
573 );
574 Some(Err(anyhow!(error)))
575 }
576 }
577 }
578 }
579 Err(error) => Some(Err(anyhow!(error))),
580 }
581 })
582 .boxed())
583 } else {
584 let mut body = String::new();
585 response
586 .body_mut()
587 .read_to_string(&mut body)
588 .await
589 .map_err(|e| RequestError::Other(e.into()))?;
590
591 Err(RequestError::HttpResponseError {
592 provider: provider_name.to_owned(),
593 status_code: response.status(),
594 body,
595 headers: response.headers().clone(),
596 })
597 }
598}
599
/// Embedding models supported by the embeddings endpoint, serialized by
/// their API model id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
607
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
613
/// Response body from the embeddings endpoint; one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
618
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
623
624pub fn embed<'a>(
625 client: &dyn HttpClient,
626 api_url: &str,
627 api_key: &str,
628 model: OpenAiEmbeddingModel,
629 texts: impl IntoIterator<Item = &'a str>,
630) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
631 let uri = format!("{api_url}/embeddings");
632
633 let request = OpenAiEmbeddingRequest {
634 model,
635 input: texts.into_iter().collect(),
636 };
637 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
638 let request = HttpRequest::builder()
639 .method(Method::POST)
640 .uri(uri)
641 .header("Content-Type", "application/json")
642 .header("Authorization", format!("Bearer {}", api_key.trim()))
643 .body(body)
644 .map(|request| client.send(request));
645
646 async move {
647 let mut response = request?.await?;
648 let mut body = String::new();
649 response.body_mut().read_to_string(&mut body).await?;
650
651 anyhow::ensure!(
652 response.status().is_success(),
653 "error during embedding, status: {:?}, body: {:?}",
654 response.status(),
655 body
656 );
657 let response: OpenAiEmbeddingResponse =
658 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
659 Ok(response)
660 }
661}