1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{
4 AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
5 http::{HeaderMap, HeaderValue},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
10use std::{convert::TryFrom, future::Future};
11use strum::EnumIter;
12use thiserror::Error;
13
/// Base URL of the OpenAI REST API (`/v1`), written without a trailing slash.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
15
/// Returns `true` when `opt` is `None` or wraps a value whose slice view is empty.
///
/// Suitable as a serde `skip_serializing_if` predicate for optional collections.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
19
/// Author of a chat message.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
28
29impl TryFrom<String> for Role {
30 type Error = anyhow::Error;
31
32 fn try_from(value: String) -> Result<Self> {
33 match value.as_str() {
34 "user" => Ok(Self::User),
35 "assistant" => Ok(Self::Assistant),
36 "system" => Ok(Self::System),
37 "tool" => Ok(Self::Tool),
38 _ => anyhow::bail!("invalid role '{value}'"),
39 }
40 }
41}
42
43impl From<Role> for String {
44 fn from(val: Role) -> Self {
45 match val {
46 Role::User => "user".to_owned(),
47 Role::Assistant => "assistant".to_owned(),
48 Role::System => "system".to_owned(),
49 Role::Tool => "tool".to_owned(),
50 }
51 }
52}
53
/// The set of known OpenAI models plus a user-defined `Custom` escape hatch.
///
/// Each built-in variant (de)serializes as the model id given in its
/// `#[serde(rename = ...)]` attribute. `FourOmni` is the `Default` variant.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A model described entirely by its fields rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        /// Identifier returned by `Model::id`.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Value returned by `Model::max_token_count`.
        max_tokens: u64,
        /// Value returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        // NOTE(review): not read by any method visible in this file — confirm where it is consumed.
        max_completion_tokens: Option<u64>,
        /// Value returned by `Model::reasoning_effort`.
        reasoning_effort: Option<ReasoningEffort>,
        /// Value returned by `Model::supports_chat_completions`; `true` when absent from the input.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
109
/// Serde default for `Model::Custom::supports_chat_completions`: custom models
/// support chat completions unless their definition says otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
113
114impl Model {
115 pub fn default_fast() -> Self {
116 // TODO: Replace with FiveMini since all other models are deprecated
117 Self::FourPointOneMini
118 }
119
120 pub fn from_id(id: &str) -> Result<Self> {
121 match id {
122 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
123 "gpt-4" => Ok(Self::Four),
124 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
125 "gpt-4o" => Ok(Self::FourOmni),
126 "gpt-4o-mini" => Ok(Self::FourOmniMini),
127 "gpt-4.1" => Ok(Self::FourPointOne),
128 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
129 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
130 "o1" => Ok(Self::O1),
131 "o3-mini" => Ok(Self::O3Mini),
132 "o3" => Ok(Self::O3),
133 "o4-mini" => Ok(Self::O4Mini),
134 "gpt-5" => Ok(Self::Five),
135 "gpt-5-codex" => Ok(Self::FiveCodex),
136 "gpt-5-mini" => Ok(Self::FiveMini),
137 "gpt-5-nano" => Ok(Self::FiveNano),
138 "gpt-5.1" => Ok(Self::FivePointOne),
139 "gpt-5.2" => Ok(Self::FivePointTwo),
140 "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
141 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
142 }
143 }
144
145 pub fn id(&self) -> &str {
146 match self {
147 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
148 Self::Four => "gpt-4",
149 Self::FourTurbo => "gpt-4-turbo",
150 Self::FourOmni => "gpt-4o",
151 Self::FourOmniMini => "gpt-4o-mini",
152 Self::FourPointOne => "gpt-4.1",
153 Self::FourPointOneMini => "gpt-4.1-mini",
154 Self::FourPointOneNano => "gpt-4.1-nano",
155 Self::O1 => "o1",
156 Self::O3Mini => "o3-mini",
157 Self::O3 => "o3",
158 Self::O4Mini => "o4-mini",
159 Self::Five => "gpt-5",
160 Self::FiveCodex => "gpt-5-codex",
161 Self::FiveMini => "gpt-5-mini",
162 Self::FiveNano => "gpt-5-nano",
163 Self::FivePointOne => "gpt-5.1",
164 Self::FivePointTwo => "gpt-5.2",
165 Self::FivePointTwoCodex => "gpt-5.2-codex",
166 Self::Custom { name, .. } => name,
167 }
168 }
169
170 pub fn display_name(&self) -> &str {
171 match self {
172 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
173 Self::Four => "gpt-4",
174 Self::FourTurbo => "gpt-4-turbo",
175 Self::FourOmni => "gpt-4o",
176 Self::FourOmniMini => "gpt-4o-mini",
177 Self::FourPointOne => "gpt-4.1",
178 Self::FourPointOneMini => "gpt-4.1-mini",
179 Self::FourPointOneNano => "gpt-4.1-nano",
180 Self::O1 => "o1",
181 Self::O3Mini => "o3-mini",
182 Self::O3 => "o3",
183 Self::O4Mini => "o4-mini",
184 Self::Five => "gpt-5",
185 Self::FiveCodex => "gpt-5-codex",
186 Self::FiveMini => "gpt-5-mini",
187 Self::FiveNano => "gpt-5-nano",
188 Self::FivePointOne => "gpt-5.1",
189 Self::FivePointTwo => "gpt-5.2",
190 Self::FivePointTwoCodex => "gpt-5.2-codex",
191 Self::Custom {
192 name, display_name, ..
193 } => display_name.as_ref().unwrap_or(name),
194 }
195 }
196
197 pub fn max_token_count(&self) -> u64 {
198 match self {
199 Self::ThreePointFiveTurbo => 16_385,
200 Self::Four => 8_192,
201 Self::FourTurbo => 128_000,
202 Self::FourOmni => 128_000,
203 Self::FourOmniMini => 128_000,
204 Self::FourPointOne => 1_047_576,
205 Self::FourPointOneMini => 1_047_576,
206 Self::FourPointOneNano => 1_047_576,
207 Self::O1 => 200_000,
208 Self::O3Mini => 200_000,
209 Self::O3 => 200_000,
210 Self::O4Mini => 200_000,
211 Self::Five => 272_000,
212 Self::FiveCodex => 272_000,
213 Self::FiveMini => 272_000,
214 Self::FiveNano => 272_000,
215 Self::FivePointOne => 400_000,
216 Self::FivePointTwo => 400_000,
217 Self::FivePointTwoCodex => 400_000,
218 Self::Custom { max_tokens, .. } => *max_tokens,
219 }
220 }
221
222 pub fn max_output_tokens(&self) -> Option<u64> {
223 match self {
224 Self::Custom {
225 max_output_tokens, ..
226 } => *max_output_tokens,
227 Self::ThreePointFiveTurbo => Some(4_096),
228 Self::Four => Some(8_192),
229 Self::FourTurbo => Some(4_096),
230 Self::FourOmni => Some(16_384),
231 Self::FourOmniMini => Some(16_384),
232 Self::FourPointOne => Some(32_768),
233 Self::FourPointOneMini => Some(32_768),
234 Self::FourPointOneNano => Some(32_768),
235 Self::O1 => Some(100_000),
236 Self::O3Mini => Some(100_000),
237 Self::O3 => Some(100_000),
238 Self::O4Mini => Some(100_000),
239 Self::Five => Some(128_000),
240 Self::FiveCodex => Some(128_000),
241 Self::FiveMini => Some(128_000),
242 Self::FiveNano => Some(128_000),
243 Self::FivePointOne => Some(128_000),
244 Self::FivePointTwo => Some(128_000),
245 Self::FivePointTwoCodex => Some(128_000),
246 }
247 }
248
249 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
250 match self {
251 Self::Custom {
252 reasoning_effort, ..
253 } => reasoning_effort.to_owned(),
254 _ => None,
255 }
256 }
257
258 pub fn supports_chat_completions(&self) -> bool {
259 match self {
260 Self::Custom {
261 supports_chat_completions,
262 ..
263 } => *supports_chat_completions,
264 Self::FiveCodex | Self::FivePointTwoCodex => false,
265 _ => true,
266 }
267 }
268
269 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
270 ///
271 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
272 pub fn supports_parallel_tool_calls(&self) -> bool {
273 match self {
274 Self::ThreePointFiveTurbo
275 | Self::Four
276 | Self::FourTurbo
277 | Self::FourOmni
278 | Self::FourOmniMini
279 | Self::FourPointOne
280 | Self::FourPointOneMini
281 | Self::FourPointOneNano
282 | Self::Five
283 | Self::FiveCodex
284 | Self::FiveMini
285 | Self::FivePointOne
286 | Self::FivePointTwo
287 | Self::FivePointTwoCodex
288 | Self::FiveNano => true,
289 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
290 }
291 }
292
293 /// Returns whether the given model supports the `prompt_cache_key` parameter.
294 ///
295 /// If the model does not support the parameter, do not pass it up.
296 pub fn supports_prompt_cache_key(&self) -> bool {
297 true
298 }
299}
300
/// JSON body posted to the `/chat/completions` endpoint by `stream_completion`.
///
/// Optional fields are omitted from the serialized JSON when they are `None`
/// (or, for vectors, empty) and fall back to their defaults when deserializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to run the completion against.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<RequestMessage>,
    /// Value of the `stream` field sent to the API.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions sent with the request.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
324
/// Value of a request's `tool_choice` field.
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"required"`
/// and `"none"`; `Other` is untagged and serializes as the wrapped tool
/// definition itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
334
/// A tool offered in a chat completions request.
///
/// Internally tagged: the variant name is written to the `type` field in
/// snake case (`"function"`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
341
/// Description of a callable function carried by `ToolDefinition::Function`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Name of the function.
    pub name: String,
    /// Optional description of the function.
    pub description: Option<String>,
    /// Optional arbitrary JSON describing the parameters
    /// (presumably a JSON Schema object — confirm against the API docs).
    pub parameters: Option<Value>,
}
348
/// A single chat message.
///
/// Internally tagged by the `role` field, whose value is the lowercase
/// variant name (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Message content; optional for assistant messages.
        content: Option<MessageContent>,
        /// Tool calls attached to the message; omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Id of the tool call this message belongs to.
        tool_call_id: String,
    },
}
368
/// Content of a message: either a bare string or a list of typed parts.
///
/// Untagged, so `Plain` serializes as a JSON string and `Multipart` as a
/// JSON array.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
375
376impl MessageContent {
377 pub fn empty() -> Self {
378 MessageContent::Multipart(vec![])
379 }
380
381 pub fn push_part(&mut self, part: MessagePart) {
382 match self {
383 MessageContent::Plain(text) => {
384 *self =
385 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
386 }
387 MessageContent::Multipart(parts) if parts.is_empty() => match part {
388 MessagePart::Text { text } => *self = MessageContent::Plain(text),
389 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
390 },
391 MessageContent::Multipart(parts) => parts.push(part),
392 }
393 }
394}
395
396impl From<Vec<MessagePart>> for MessageContent {
397 fn from(mut parts: Vec<MessagePart>) -> Self {
398 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
399 MessageContent::Plain(std::mem::take(text))
400 } else {
401 MessageContent::Multipart(parts)
402 }
403 }
404}
405
/// One element of multipart message content.
///
/// Internally tagged by `type`: `"text"` for `Text` and `"image_url"` for `Image`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
414
/// Image reference carried by `MessagePart::Image`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Optional detail setting; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
421
/// A tool call attached to an assistant message.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Call payload; flattened, so its fields sit next to `id` in the JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
428
/// Payload of a `ToolCall`, internally tagged by `type` with the lowercase
/// variant name (`"function"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
434
/// Function invocation carried by `ToolCallContent::Function`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function being called.
    pub name: String,
    /// Call arguments, kept as an unparsed string.
    pub arguments: String,
}
440
/// A complete response object with its choices and token usage.
///
/// NOTE(review): presumably the non-streaming chat completion body; it is not
/// used by the functions visible in this file — confirm against callers.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
450
/// One entry of `Response::choices`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    /// Position of this choice in the response.
    pub index: u32,
    /// The message for this choice.
    pub message: RequestMessage,
    /// Why generation stopped, when reported.
    pub finish_reason: Option<String>,
}
457
/// Incremental message update carried by a `ChoiceDelta`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Tool call fragments; defaults to `None` when absent and is omitted from
    /// the serialized JSON when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
465
/// A fragment of a tool call delivered in a streamed delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call this fragment belongs to.
    pub index: usize,
    /// Tool call id, when present in this fragment.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
476
/// Function portion of a `ToolCallChunk`; either field may be missing from a
/// given fragment.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
482
/// Token counters reported alongside a response or stream event.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
489
/// One entry of `ResponseStreamEvent::choices`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Position of the choice this delta applies to.
    pub index: u32,
    /// The incremental update, when the event carries one.
    pub delta: Option<ResponseMessageDelta>,
    /// Why generation stopped, when reported.
    pub finish_reason: Option<String>,
}
496
/// Error returned by the request functions in this file.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server answered with a non-success status; carries the status, the
    /// response body text and the response headers.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, request construction, transport, body
    /// reading), displayed transparently as the wrapped error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
509
/// Error object found under the `error` key of a streamed payload.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    /// Error message; `stream_completion` turns it into the stream item's error.
    message: String,
}
514
/// A streamed payload: either a regular event or an `{"error": ...}` object.
///
/// Untagged, so deserialization tries `Ok` first and falls back to `Err`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
521
/// One successfully parsed event of a chat completions stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Per-choice deltas contained in this event.
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the event reports it.
    pub usage: Option<Usage>,
}
527
528pub async fn stream_completion(
529 client: &dyn HttpClient,
530 provider_name: &str,
531 api_url: &str,
532 api_key: &str,
533 request: Request,
534) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
535 let uri = format!("{api_url}/chat/completions");
536 let request_builder = HttpRequest::builder()
537 .method(Method::POST)
538 .uri(uri)
539 .header("Content-Type", "application/json")
540 .header("Authorization", format!("Bearer {}", api_key.trim()));
541
542 let request = request_builder
543 .body(AsyncBody::from(
544 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
545 ))
546 .map_err(|e| RequestError::Other(e.into()))?;
547
548 let mut response = client.send(request).await?;
549 if response.status().is_success() {
550 let reader = BufReader::new(response.into_body());
551 Ok(reader
552 .lines()
553 .filter_map(|line| async move {
554 match line {
555 Ok(line) => {
556 let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
557 if line == "[DONE]" {
558 None
559 } else {
560 match serde_json::from_str(line) {
561 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
562 Ok(ResponseStreamResult::Err { error }) => {
563 Some(Err(anyhow!(error.message)))
564 }
565 Err(error) => {
566 log::error!(
567 "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
568 Response: `{}`",
569 error,
570 line,
571 );
572 Some(Err(anyhow!(error)))
573 }
574 }
575 }
576 }
577 Err(error) => Some(Err(anyhow!(error))),
578 }
579 })
580 .boxed())
581 } else {
582 let mut body = String::new();
583 response
584 .body_mut()
585 .read_to_string(&mut body)
586 .await
587 .map_err(|e| RequestError::Other(e.into()))?;
588
589 Err(RequestError::HttpResponseError {
590 provider: provider_name.to_owned(),
591 status_code: response.status(),
592 body,
593 headers: response.headers().clone(),
594 })
595 }
596}
597
/// Embedding models selectable in `embed`, (de)serialized as their model ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
605
/// JSON body posted to the `/embeddings` endpoint by `embed`.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    /// Embedding model to use.
    model: OpenAiEmbeddingModel,
    /// Texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
611
/// Parsed body of a successful `/embeddings` response.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// The returned embeddings (presumably one per input text, in input
    /// order — confirm against the API docs).
    pub data: Vec<OpenAiEmbedding>,
}
616
/// A single embedding vector from an `OpenAiEmbeddingResponse`.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    /// The embedding components.
    pub embedding: Vec<f32>,
}
621
622pub fn embed<'a>(
623 client: &dyn HttpClient,
624 api_url: &str,
625 api_key: &str,
626 model: OpenAiEmbeddingModel,
627 texts: impl IntoIterator<Item = &'a str>,
628) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
629 let uri = format!("{api_url}/embeddings");
630
631 let request = OpenAiEmbeddingRequest {
632 model,
633 input: texts.into_iter().collect(),
634 };
635 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
636 let request = HttpRequest::builder()
637 .method(Method::POST)
638 .uri(uri)
639 .header("Content-Type", "application/json")
640 .header("Authorization", format!("Bearer {}", api_key.trim()))
641 .body(body)
642 .map(|request| client.send(request));
643
644 async move {
645 let mut response = request?.await?;
646 let mut body = String::new();
647 response.body_mut().read_to_string(&mut body).await?;
648
649 anyhow::ensure!(
650 response.status().is_success(),
651 "error during embedding, status: {:?}, body: {:?}",
652 response.status(),
653 body
654 );
655 let response: OpenAiEmbeddingResponse =
656 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
657 Ok(response)
658 }
659}
660
661pub mod responses {
662 use anyhow::{Result, anyhow};
663 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
664 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
665 use serde::{Deserialize, Serialize};
666 use serde_json::Value;
667
668 use crate::RequestError;
669
    /// JSON body posted to the `/responses` endpoint by `stream_response`.
    ///
    /// `Option` fields are omitted from the JSON when `None`, and `input` and
    /// `tools` are omitted when empty.
    #[derive(Serialize, Debug)]
    pub struct Request {
        /// Model id to run the request against.
        pub model: String,
        /// Input items as raw JSON values.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub input: Vec<Value>,
        /// Sent to the API as `stream`; `stream_response` also reads it to choose
        /// between parsing an event stream and parsing a single JSON body.
        #[serde(default)]
        pub stream: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parallel_tool_calls: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tool_choice: Option<super::ToolChoice>,
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub tools: Vec<ToolDefinition>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub prompt_cache_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub reasoning: Option<ReasoningConfig>,
    }
694
    /// Value of the request's `reasoning` field.
    #[derive(Serialize, Debug)]
    pub struct ReasoningConfig {
        /// Requested reasoning effort.
        pub effort: super::ReasoningEffort,
    }
699
    /// A tool offered in a responses request.
    ///
    /// Internally tagged by `type` (`"function"`). Unlike the chat completions
    /// `ToolDefinition`, the function fields sit directly on the tagged object.
    #[derive(Serialize, Debug)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ToolDefinition {
        Function {
            /// Name of the function.
            name: String,
            /// Optional description; omitted from the JSON when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            description: Option<String>,
            /// Optional arbitrary JSON describing the parameters; omitted when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            parameters: Option<Value>,
            /// Optional `strict` flag; omitted when `None`.
            #[serde(skip_serializing_if = "Option::is_none")]
            strict: Option<bool>,
        },
    }
713
    /// Error object carried by the `Error` and `GenericError` stream events.
    #[derive(Deserialize, Debug)]
    pub struct Error {
        /// Error message.
        pub message: String,
    }
718
    /// An event of the responses stream.
    ///
    /// Internally tagged: the `type` field of the JSON object selects the
    /// variant. Any `type` not listed here deserializes to `Unknown`.
    #[derive(Deserialize, Debug)]
    #[serde(tag = "type")]
    pub enum StreamEvent {
        #[serde(rename = "response.created")]
        Created { response: ResponseSummary },
        #[serde(rename = "response.in_progress")]
        InProgress { response: ResponseSummary },
        #[serde(rename = "response.output_item.added")]
        OutputItemAdded {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        #[serde(rename = "response.output_item.done")]
        OutputItemDone {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        #[serde(rename = "response.content_part.added")]
        ContentPartAdded {
            item_id: String,
            output_index: usize,
            content_index: usize,
            part: Value,
        },
        #[serde(rename = "response.content_part.done")]
        ContentPartDone {
            item_id: String,
            output_index: usize,
            content_index: usize,
            part: Value,
        },
        #[serde(rename = "response.output_text.delta")]
        OutputTextDelta {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            delta: String,
        },
        #[serde(rename = "response.output_text.done")]
        OutputTextDone {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            text: String,
        },
        #[serde(rename = "response.function_call_arguments.delta")]
        FunctionCallArgumentsDelta {
            item_id: String,
            output_index: usize,
            delta: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        #[serde(rename = "response.function_call_arguments.done")]
        FunctionCallArgumentsDone {
            item_id: String,
            output_index: usize,
            arguments: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        #[serde(rename = "response.completed")]
        Completed { response: ResponseSummary },
        #[serde(rename = "response.incomplete")]
        Incomplete { response: ResponseSummary },
        #[serde(rename = "response.failed")]
        Failed { response: ResponseSummary },
        #[serde(rename = "response.error")]
        Error { error: Error },
        #[serde(rename = "error")]
        GenericError { error: Error },
        /// Catch-all for event types this enum does not model.
        #[serde(other)]
        Unknown,
    }
799
    /// Response object embedded in lifecycle events and returned as the whole
    /// body of a non-streaming request.
    ///
    /// Every field is optional on the wire and falls back to its default
    /// (`None` or an empty vector) when missing.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseSummary {
        #[serde(default)]
        pub id: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
        #[serde(default)]
        pub status_details: Option<ResponseStatusDetails>,
        #[serde(default)]
        pub usage: Option<ResponseUsage>,
        /// Output items; `stream_response` replays these as events for
        /// non-streaming requests.
        #[serde(default)]
        pub output: Vec<ResponseOutputItem>,
    }
813
    /// Optional details accompanying a response's status; every field defaults
    /// to `None` when missing.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseStatusDetails {
        #[serde(default)]
        pub reason: Option<String>,
        #[serde(default)]
        pub r#type: Option<String>,
        /// Error payload, kept as raw JSON.
        #[serde(default)]
        pub error: Option<Value>,
    }
823
    /// Token counters attached to a response; every field defaults to `None`
    /// when missing.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseUsage {
        #[serde(default)]
        pub input_tokens: Option<u64>,
        #[serde(default)]
        pub output_tokens: Option<u64>,
        #[serde(default)]
        pub total_tokens: Option<u64>,
    }
833
    /// One item of a response's output.
    ///
    /// Internally tagged by `type` in snake case (`"message"`,
    /// `"function_call"`); any other type deserializes to `Unknown`.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ResponseOutputItem {
        Message(ResponseOutputMessage),
        FunctionCall(ResponseFunctionToolCall),
        #[serde(other)]
        Unknown,
    }
842
    /// A message output item; every field falls back to its default when missing.
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseOutputMessage {
        #[serde(default)]
        pub id: Option<String>,
        /// Content parts as raw JSON; `stream_response` reads the string `text`
        /// field of each part when replaying a non-streaming response.
        #[serde(default)]
        pub content: Vec<Value>,
        #[serde(default)]
        pub role: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
854
    /// A function call output item; every field falls back to its default when
    /// missing (`arguments` becomes the empty string).
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseFunctionToolCall {
        #[serde(default)]
        pub id: Option<String>,
        /// Call arguments, kept as an unparsed string.
        #[serde(default)]
        pub arguments: String,
        #[serde(default)]
        pub call_id: Option<String>,
        #[serde(default)]
        pub name: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
868
869 pub async fn stream_response(
870 client: &dyn HttpClient,
871 provider_name: &str,
872 api_url: &str,
873 api_key: &str,
874 request: Request,
875 ) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
876 let uri = format!("{api_url}/responses");
877 let request_builder = HttpRequest::builder()
878 .method(Method::POST)
879 .uri(uri)
880 .header("Content-Type", "application/json")
881 .header("Authorization", format!("Bearer {}", api_key.trim()));
882
883 let is_streaming = request.stream;
884 let request = request_builder
885 .body(AsyncBody::from(
886 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
887 ))
888 .map_err(|e| RequestError::Other(e.into()))?;
889
890 let mut response = client.send(request).await?;
891 if response.status().is_success() {
892 if is_streaming {
893 let reader = BufReader::new(response.into_body());
894 Ok(reader
895 .lines()
896 .filter_map(|line| async move {
897 match line {
898 Ok(line) => {
899 let line = line
900 .strip_prefix("data: ")
901 .or_else(|| line.strip_prefix("data:"))?;
902 if line == "[DONE]" || line.is_empty() {
903 None
904 } else {
905 match serde_json::from_str::<StreamEvent>(line) {
906 Ok(event) => Some(Ok(event)),
907 Err(error) => {
908 log::error!(
909 "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
910 error,
911 line,
912 );
913 Some(Err(anyhow!(error)))
914 }
915 }
916 }
917 }
918 Err(error) => Some(Err(anyhow!(error))),
919 }
920 })
921 .boxed())
922 } else {
923 let mut body = String::new();
924 response
925 .body_mut()
926 .read_to_string(&mut body)
927 .await
928 .map_err(|e| RequestError::Other(e.into()))?;
929
930 match serde_json::from_str::<ResponseSummary>(&body) {
931 Ok(response_summary) => {
932 let events = vec![
933 StreamEvent::Created {
934 response: response_summary.clone(),
935 },
936 StreamEvent::InProgress {
937 response: response_summary.clone(),
938 },
939 ];
940
941 let mut all_events = events;
942 for (output_index, item) in response_summary.output.iter().enumerate() {
943 all_events.push(StreamEvent::OutputItemAdded {
944 output_index,
945 sequence_number: None,
946 item: item.clone(),
947 });
948
949 match item {
950 ResponseOutputItem::Message(message) => {
951 for content_item in &message.content {
952 if let Some(text) = content_item.get("text") {
953 if let Some(text_str) = text.as_str() {
954 if let Some(ref item_id) = message.id {
955 all_events.push(StreamEvent::OutputTextDelta {
956 item_id: item_id.clone(),
957 output_index,
958 content_index: None,
959 delta: text_str.to_string(),
960 });
961 }
962 }
963 }
964 }
965 }
966 ResponseOutputItem::FunctionCall(function_call) => {
967 if let Some(ref item_id) = function_call.id {
968 all_events.push(StreamEvent::FunctionCallArgumentsDone {
969 item_id: item_id.clone(),
970 output_index,
971 arguments: function_call.arguments.clone(),
972 sequence_number: None,
973 });
974 }
975 }
976 ResponseOutputItem::Unknown => {}
977 }
978
979 all_events.push(StreamEvent::OutputItemDone {
980 output_index,
981 sequence_number: None,
982 item: item.clone(),
983 });
984 }
985
986 all_events.push(StreamEvent::Completed {
987 response: response_summary,
988 });
989
990 Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
991 }
992 Err(error) => {
993 log::error!(
994 "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
995 error,
996 body,
997 );
998 Err(RequestError::Other(anyhow!(error)))
999 }
1000 }
1001 }
1002 } else {
1003 let mut body = String::new();
1004 response
1005 .body_mut()
1006 .read_to_string(&mut body)
1007 .await
1008 .map_err(|e| RequestError::Other(e.into()))?;
1009
1010 Err(RequestError::HttpResponseError {
1011 provider: provider_name.to_owned(),
1012 status_code: response.status(),
1013 body,
1014 headers: response.headers().clone(),
1015 })
1016 }
1017 }
1018}