1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{
4 AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
5 http::{HeaderMap, HeaderValue},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
10use std::{convert::TryFrom, future::Future};
11use strum::EnumIter;
12use thiserror::Error;
13
/// Base URL for OpenAI's hosted v1 API; override via `api_url` for compatible providers.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
15
/// Serde `skip_serializing_if` helper: true when the option is absent or the
/// contained slice-like value has no elements.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(value) => value.as_ref().is_empty(),
    }
}
19
/// The role of a chat message, serialized in OpenAI's lowercase wire format
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
28
29impl TryFrom<String> for Role {
30 type Error = anyhow::Error;
31
32 fn try_from(value: String) -> Result<Self> {
33 match value.as_str() {
34 "user" => Ok(Self::User),
35 "assistant" => Ok(Self::Assistant),
36 "system" => Ok(Self::System),
37 "tool" => Ok(Self::Tool),
38 _ => anyhow::bail!("invalid role '{value}'"),
39 }
40 }
41}
42
43impl From<Role> for String {
44 fn from(val: Role) -> Self {
45 match val {
46 Role::User => "user".to_owned(),
47 Role::Assistant => "assistant".to_owned(),
48 Role::System => "system".to_owned(),
49 Role::Tool => "tool".to_owned(),
50 }
51 }
52}
53
/// The set of OpenAI models known to this client, plus a `Custom` escape
/// hatch for arbitrary (e.g. OpenAI-compatible) model ids.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Maximum context window size, in tokens.
        max_tokens: u64,
        /// Maximum output tokens, if known.
        max_output_tokens: Option<u64>,
        /// Cap passed as `max_completion_tokens`, if configured.
        max_completion_tokens: Option<u64>,
        /// Reasoning effort to request, if configured.
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether the model supports the Chat Completions API (defaults to true).
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
107
/// Serde default for `Model::Custom::supports_chat_completions`: custom models
/// are assumed to support the Chat Completions API unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
111
112impl Model {
113 pub fn default_fast() -> Self {
114 // TODO: Replace with FiveMini since all other models are deprecated
115 Self::FourPointOneMini
116 }
117
118 pub fn from_id(id: &str) -> Result<Self> {
119 match id {
120 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
121 "gpt-4" => Ok(Self::Four),
122 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
123 "gpt-4o" => Ok(Self::FourOmni),
124 "gpt-4o-mini" => Ok(Self::FourOmniMini),
125 "gpt-4.1" => Ok(Self::FourPointOne),
126 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
127 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
128 "o1" => Ok(Self::O1),
129 "o3-mini" => Ok(Self::O3Mini),
130 "o3" => Ok(Self::O3),
131 "o4-mini" => Ok(Self::O4Mini),
132 "gpt-5" => Ok(Self::Five),
133 "gpt-5-codex" => Ok(Self::FiveCodex),
134 "gpt-5-mini" => Ok(Self::FiveMini),
135 "gpt-5-nano" => Ok(Self::FiveNano),
136 "gpt-5.1" => Ok(Self::FivePointOne),
137 "gpt-5.2" => Ok(Self::FivePointTwo),
138 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
139 }
140 }
141
142 pub fn id(&self) -> &str {
143 match self {
144 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
145 Self::Four => "gpt-4",
146 Self::FourTurbo => "gpt-4-turbo",
147 Self::FourOmni => "gpt-4o",
148 Self::FourOmniMini => "gpt-4o-mini",
149 Self::FourPointOne => "gpt-4.1",
150 Self::FourPointOneMini => "gpt-4.1-mini",
151 Self::FourPointOneNano => "gpt-4.1-nano",
152 Self::O1 => "o1",
153 Self::O3Mini => "o3-mini",
154 Self::O3 => "o3",
155 Self::O4Mini => "o4-mini",
156 Self::Five => "gpt-5",
157 Self::FiveCodex => "gpt-5-codex",
158 Self::FiveMini => "gpt-5-mini",
159 Self::FiveNano => "gpt-5-nano",
160 Self::FivePointOne => "gpt-5.1",
161 Self::FivePointTwo => "gpt-5.2",
162 Self::Custom { name, .. } => name,
163 }
164 }
165
166 pub fn display_name(&self) -> &str {
167 match self {
168 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
169 Self::Four => "gpt-4",
170 Self::FourTurbo => "gpt-4-turbo",
171 Self::FourOmni => "gpt-4o",
172 Self::FourOmniMini => "gpt-4o-mini",
173 Self::FourPointOne => "gpt-4.1",
174 Self::FourPointOneMini => "gpt-4.1-mini",
175 Self::FourPointOneNano => "gpt-4.1-nano",
176 Self::O1 => "o1",
177 Self::O3Mini => "o3-mini",
178 Self::O3 => "o3",
179 Self::O4Mini => "o4-mini",
180 Self::Five => "gpt-5",
181 Self::FiveCodex => "gpt-5-codex",
182 Self::FiveMini => "gpt-5-mini",
183 Self::FiveNano => "gpt-5-nano",
184 Self::FivePointOne => "gpt-5.1",
185 Self::FivePointTwo => "gpt-5.2",
186 Self::Custom {
187 name, display_name, ..
188 } => display_name.as_ref().unwrap_or(name),
189 }
190 }
191
192 pub fn max_token_count(&self) -> u64 {
193 match self {
194 Self::ThreePointFiveTurbo => 16_385,
195 Self::Four => 8_192,
196 Self::FourTurbo => 128_000,
197 Self::FourOmni => 128_000,
198 Self::FourOmniMini => 128_000,
199 Self::FourPointOne => 1_047_576,
200 Self::FourPointOneMini => 1_047_576,
201 Self::FourPointOneNano => 1_047_576,
202 Self::O1 => 200_000,
203 Self::O3Mini => 200_000,
204 Self::O3 => 200_000,
205 Self::O4Mini => 200_000,
206 Self::Five => 272_000,
207 Self::FiveCodex => 272_000,
208 Self::FiveMini => 272_000,
209 Self::FiveNano => 272_000,
210 Self::FivePointOne => 400_000,
211 Self::FivePointTwo => 400_000,
212 Self::Custom { max_tokens, .. } => *max_tokens,
213 }
214 }
215
216 pub fn max_output_tokens(&self) -> Option<u64> {
217 match self {
218 Self::Custom {
219 max_output_tokens, ..
220 } => *max_output_tokens,
221 Self::ThreePointFiveTurbo => Some(4_096),
222 Self::Four => Some(8_192),
223 Self::FourTurbo => Some(4_096),
224 Self::FourOmni => Some(16_384),
225 Self::FourOmniMini => Some(16_384),
226 Self::FourPointOne => Some(32_768),
227 Self::FourPointOneMini => Some(32_768),
228 Self::FourPointOneNano => Some(32_768),
229 Self::O1 => Some(100_000),
230 Self::O3Mini => Some(100_000),
231 Self::O3 => Some(100_000),
232 Self::O4Mini => Some(100_000),
233 Self::Five => Some(128_000),
234 Self::FiveCodex => Some(128_000),
235 Self::FiveMini => Some(128_000),
236 Self::FiveNano => Some(128_000),
237 Self::FivePointOne => Some(128_000),
238 Self::FivePointTwo => Some(128_000),
239 }
240 }
241
242 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
243 match self {
244 Self::Custom {
245 reasoning_effort, ..
246 } => reasoning_effort.to_owned(),
247 _ => None,
248 }
249 }
250
251 pub fn supports_chat_completions(&self) -> bool {
252 match self {
253 Self::Custom {
254 supports_chat_completions,
255 ..
256 } => *supports_chat_completions,
257 Self::FiveCodex => false,
258 _ => true,
259 }
260 }
261
262 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
263 ///
264 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
265 pub fn supports_parallel_tool_calls(&self) -> bool {
266 match self {
267 Self::ThreePointFiveTurbo
268 | Self::Four
269 | Self::FourTurbo
270 | Self::FourOmni
271 | Self::FourOmniMini
272 | Self::FourPointOne
273 | Self::FourPointOneMini
274 | Self::FourPointOneNano
275 | Self::Five
276 | Self::FiveCodex
277 | Self::FiveMini
278 | Self::FivePointOne
279 | Self::FivePointTwo
280 | Self::FiveNano => true,
281 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
282 }
283 }
284
285 /// Returns whether the given model supports the `prompt_cache_key` parameter.
286 ///
287 /// If the model does not support the parameter, do not pass it up.
288 pub fn supports_prompt_cache_key(&self) -> bool {
289 true
290 }
291}
292
/// Request body for OpenAI's Chat Completions endpoint (`/chat/completions`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to run the completion with (see [`Model::id`]).
    pub model: String,
    /// Conversation history, oldest message first.
    pub messages: Vec<RequestMessage>,
    /// Whether to stream the response as server-sent events.
    pub stream: bool,
    /// Upper bound on generated tokens, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Sequences at which the model stops generating further tokens.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; omitted to use the API default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Constraint on which tool (if any) the model must call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools (functions) the model may call.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Key used by the API for prompt caching; only pass for models that
    /// support it (see `Model::supports_prompt_cache_key`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Requested reasoning effort, for models that accept it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
316
/// Controls which (if any) tool the model is required to call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// The model decides whether to call a tool.
    Auto,
    /// The model must call at least one tool.
    Required,
    /// The model must not call any tool.
    None,
    /// Force a call to one specific tool (serialized as the tool itself).
    #[serde(untagged)]
    Other(ToolDefinition),
}
326
/// A tool exposed to the model; currently only function tools exist.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
333
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// The function's name, echoed back in tool calls.
    pub name: String,
    /// Natural-language description shown to the model.
    pub description: Option<String>,
    /// JSON Schema describing the function's arguments.
    pub parameters: Option<Value>,
}
340
/// A single chat message, tagged by role in OpenAI's wire format.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the turn is only tool calls.
        content: Option<MessageContent>,
        /// Tool calls the assistant made in this turn.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// Result of a tool call, correlated to the call via `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
360
/// Message content: either a plain string or a list of multipart parts
/// (text and images). Serialized untagged to match OpenAI's schema.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
367
368impl MessageContent {
369 pub fn empty() -> Self {
370 MessageContent::Multipart(vec![])
371 }
372
373 pub fn push_part(&mut self, part: MessagePart) {
374 match self {
375 MessageContent::Plain(text) => {
376 *self =
377 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
378 }
379 MessageContent::Multipart(parts) if parts.is_empty() => match part {
380 MessagePart::Text { text } => *self = MessageContent::Plain(text),
381 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
382 },
383 MessageContent::Multipart(parts) => parts.push(part),
384 }
385 }
386}
387
388impl From<Vec<MessagePart>> for MessageContent {
389 fn from(mut parts: Vec<MessagePart>) -> Self {
390 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
391 MessageContent::Plain(std::mem::take(text))
392 } else {
393 MessageContent::Multipart(parts)
394 }
395 }
396}
397
/// One element of multipart message content.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
406
/// An image reference in multipart content (URL or data URI).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Requested processing detail level, per the API; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
413
/// A tool call made by the assistant; `id` correlates the eventual
/// `RequestMessage::Tool` result back to this call.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
420
/// The payload of a tool call; currently only function calls exist.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
426
/// A concrete function invocation: name plus raw JSON-encoded arguments.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// JSON-encoded argument object, as produced by the model.
    pub arguments: String,
}
432
/// A complete (non-streaming) chat completion response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Unix timestamp of creation, per the API.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
442
/// One generated completion candidate.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    /// Why generation stopped (e.g. "stop", "length"), when reported.
    pub finish_reason: Option<String>,
}
449
/// Incremental message fragment delivered in a streaming chunk; every field
/// may be absent because the API streams pieces independently.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
457
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    /// Call id; typically only present on the first chunk of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
468
/// A streamed fragment of a function call's name and/or arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    /// Partial JSON of the arguments; concatenate chunks to reconstruct.
    pub arguments: Option<String>,
}
474
/// Token accounting reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
481
/// One choice's incremental update within a streaming chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    /// Why generation stopped; only present on the final chunk of a choice.
    pub finish_reason: Option<String>,
}
488
/// Errors produced when issuing a completion/response request.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The API answered with a non-success HTTP status; the raw body and
    /// response headers are preserved so callers can inspect rate-limit info.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, transport, request building).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
501
/// An error payload delivered inline in the event stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
506
/// A streamed SSE payload: either a normal event or an inline error object;
/// untagged so serde tries the event shape first.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
513
/// One parsed streaming chunk of a chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    /// Usage totals; typically only present on the final chunk.
    pub usage: Option<Usage>,
}
519
/// Sends a Chat Completions request and returns a stream of parsed
/// server-sent events.
///
/// Non-`data:` lines and the terminal `[DONE]` marker are filtered out of the
/// stream; individual events that fail to parse are surfaced as `Err` items
/// within the stream rather than terminating it.
///
/// # Errors
///
/// Returns [`RequestError::HttpResponseError`] when the API responds with a
/// non-success status, or [`RequestError::Other`] for serialization, request
/// building, and transport failures.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE payloads may arrive as "data: {...}" or "data:{...}";
                        // lines with neither prefix are dropped.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the raw payload so parse failures are debuggable.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Error path: read the full body so the caller gets the API's message.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
589
/// Embedding models supported by OpenAI's `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
597
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
603
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
608
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
613
/// Requests embeddings for `texts` from OpenAI's `/embeddings` endpoint.
///
/// The HTTP request (including the serialized input texts) is fully built
/// before the future is created, which is what lets the returned future be
/// `'static` and not borrow any of the arguments.
///
/// # Errors
///
/// Fails if the request cannot be built or sent, if the API returns a
/// non-success status, or if the response body cannot be parsed.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serializing this plain struct can only fail on a bug, hence `unwrap`.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()))
        .body(body)
        .map(|request| client.send(request));

    async move {
        // `request` is a Result: surface any request-building error first.
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}
652
653pub mod responses {
654 use anyhow::{Result, anyhow};
655 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
656 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
657 use serde::{Deserialize, Serialize};
658 use serde_json::Value;
659
660 use crate::RequestError;
661
    /// Request body for OpenAI's Responses API endpoint (`/responses`).
    #[derive(Serialize, Debug)]
    pub struct Request {
        /// Model id to run the response with.
        pub model: String,
        /// Input items (messages, tool results) as raw JSON values.
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub input: Vec<Value>,
        /// Whether to stream the response as server-sent events.
        #[serde(default)]
        pub stream: bool,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub temperature: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub top_p: Option<f32>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub max_output_tokens: Option<u64>,
        /// Whether to enable parallel function calling during tool use.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub parallel_tool_calls: Option<bool>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub tool_choice: Option<super::ToolChoice>,
        /// Tools the model may call (Responses API flattened shape).
        #[serde(skip_serializing_if = "Vec::is_empty")]
        pub tools: Vec<ToolDefinition>,
        #[serde(skip_serializing_if = "Option::is_none")]
        pub prompt_cache_key: Option<String>,
        /// Reasoning configuration, for reasoning-capable models.
        #[serde(skip_serializing_if = "Option::is_none")]
        pub reasoning: Option<ReasoningConfig>,
    }
686
    /// Reasoning options for the Responses API.
    #[derive(Serialize, Debug)]
    pub struct ReasoningConfig {
        pub effort: super::ReasoningEffort,
    }
691
    /// A tool definition in the Responses API shape, where function fields
    /// are flattened onto the tool object (unlike the Chat Completions shape).
    #[derive(Serialize, Debug)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ToolDefinition {
        Function {
            name: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            description: Option<String>,
            /// JSON Schema for the function's arguments.
            #[serde(skip_serializing_if = "Option::is_none")]
            parameters: Option<Value>,
            /// Whether the API should enforce the schema strictly.
            #[serde(skip_serializing_if = "Option::is_none")]
            strict: Option<bool>,
        },
    }
705
    /// An error payload carried inside a stream event.
    #[derive(Deserialize, Debug)]
    pub struct Error {
        pub message: String,
    }
710
    /// Server-sent events emitted by the Responses API stream, tagged by the
    /// event `type`. Unrecognized event types deserialize to
    /// [`StreamEvent::Unknown`] so new API events don't break consumers.
    #[derive(Deserialize, Debug)]
    #[serde(tag = "type")]
    pub enum StreamEvent {
        /// The response has been created.
        #[serde(rename = "response.created")]
        Created { response: ResponseSummary },
        /// The response is being generated.
        #[serde(rename = "response.in_progress")]
        InProgress { response: ResponseSummary },
        /// A new output item (message or function call) was added.
        #[serde(rename = "response.output_item.added")]
        OutputItemAdded {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        /// An output item is complete.
        #[serde(rename = "response.output_item.done")]
        OutputItemDone {
            output_index: usize,
            #[serde(default)]
            sequence_number: Option<u64>,
            item: ResponseOutputItem,
        },
        #[serde(rename = "response.content_part.added")]
        ContentPartAdded {
            item_id: String,
            output_index: usize,
            content_index: usize,
            part: Value,
        },
        #[serde(rename = "response.content_part.done")]
        ContentPartDone {
            item_id: String,
            output_index: usize,
            content_index: usize,
            part: Value,
        },
        /// An incremental chunk of output text.
        #[serde(rename = "response.output_text.delta")]
        OutputTextDelta {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            delta: String,
        },
        /// The complete text of one output item.
        #[serde(rename = "response.output_text.done")]
        OutputTextDone {
            item_id: String,
            output_index: usize,
            #[serde(default)]
            content_index: Option<usize>,
            text: String,
        },
        /// An incremental chunk of a function call's JSON arguments.
        #[serde(rename = "response.function_call_arguments.delta")]
        FunctionCallArgumentsDelta {
            item_id: String,
            output_index: usize,
            delta: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        /// The complete JSON arguments of one function call.
        #[serde(rename = "response.function_call_arguments.done")]
        FunctionCallArgumentsDone {
            item_id: String,
            output_index: usize,
            arguments: String,
            #[serde(default)]
            sequence_number: Option<u64>,
        },
        #[serde(rename = "response.completed")]
        Completed { response: ResponseSummary },
        #[serde(rename = "response.incomplete")]
        Incomplete { response: ResponseSummary },
        #[serde(rename = "response.failed")]
        Failed { response: ResponseSummary },
        #[serde(rename = "response.error")]
        Error { error: Error },
        /// Top-level `error` events (not namespaced under `response.`).
        #[serde(rename = "error")]
        GenericError { error: Error },
        /// Any event type this client doesn't recognize.
        #[serde(other)]
        Unknown,
    }
791
    /// Summary of a response's state, carried by lifecycle events; every
    /// field is optional/defaulted because events include varying subsets.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseSummary {
        #[serde(default)]
        pub id: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
        #[serde(default)]
        pub status_details: Option<ResponseStatusDetails>,
        #[serde(default)]
        pub usage: Option<ResponseUsage>,
        /// Output items; populated on completed/non-streaming responses.
        #[serde(default)]
        pub output: Vec<ResponseOutputItem>,
    }
805
    /// Extra detail about why a response ended in its status.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseStatusDetails {
        #[serde(default)]
        pub reason: Option<String>,
        #[serde(default)]
        pub r#type: Option<String>,
        /// Raw error object, when the response failed.
        #[serde(default)]
        pub error: Option<Value>,
    }
815
    /// Token accounting for a Responses API response.
    #[derive(Deserialize, Debug, Default, Clone)]
    pub struct ResponseUsage {
        #[serde(default)]
        pub input_tokens: Option<u64>,
        #[serde(default)]
        pub output_tokens: Option<u64>,
        #[serde(default)]
        pub total_tokens: Option<u64>,
    }
825
    /// One item of a response's output; unknown types are tolerated so new
    /// output kinds don't break deserialization.
    #[derive(Deserialize, Debug, Clone)]
    #[serde(tag = "type", rename_all = "snake_case")]
    pub enum ResponseOutputItem {
        Message(ResponseOutputMessage),
        FunctionCall(ResponseFunctionToolCall),
        #[serde(other)]
        Unknown,
    }
834
    /// A message output item (assistant text content).
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseOutputMessage {
        #[serde(default)]
        pub id: Option<String>,
        /// Content parts as raw JSON values.
        #[serde(default)]
        pub content: Vec<Value>,
        #[serde(default)]
        pub role: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
846
    /// A function-call output item.
    #[derive(Deserialize, Debug, Clone)]
    pub struct ResponseFunctionToolCall {
        #[serde(default)]
        pub id: Option<String>,
        /// JSON-encoded arguments; defaults to empty when absent.
        #[serde(default)]
        pub arguments: String,
        /// Id used to correlate the tool result back to this call.
        #[serde(default)]
        pub call_id: Option<String>,
        #[serde(default)]
        pub name: Option<String>,
        #[serde(default)]
        pub status: Option<String>,
    }
860
    /// Sends a Responses API request and returns a stream of [`StreamEvent`]s.
    ///
    /// When `request.stream` is false, the single JSON response is converted
    /// into the same event sequence a streaming response would produce
    /// (Created, InProgress, per-item Added/Delta/Done, Completed), so callers
    /// can consume both modes through one interface.
    ///
    /// # Errors
    ///
    /// Returns [`RequestError::HttpResponseError`] for non-success statuses,
    /// or [`RequestError::Other`] for serialization, transport, and
    /// non-streaming parse failures. Streaming parse failures of individual
    /// events are surfaced as `Err` items within the stream.
    pub async fn stream_response(
        client: &dyn HttpClient,
        provider_name: &str,
        api_url: &str,
        api_key: &str,
        request: Request,
    ) -> Result<BoxStream<'static, Result<StreamEvent>>, RequestError> {
        let uri = format!("{api_url}/responses");
        let request_builder = HttpRequest::builder()
            .method(Method::POST)
            .uri(uri)
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", api_key.trim()));

        // Remember the mode before `request` is consumed by serialization.
        let is_streaming = request.stream;
        let request = request_builder
            .body(AsyncBody::from(
                serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
            ))
            .map_err(|e| RequestError::Other(e.into()))?;

        let mut response = client.send(request).await?;
        if response.status().is_success() {
            if is_streaming {
                let reader = BufReader::new(response.into_body());
                Ok(reader
                    .lines()
                    .filter_map(|line| async move {
                        match line {
                            Ok(line) => {
                                // SSE payloads may arrive as "data: {...}" or "data:{...}".
                                let line = line
                                    .strip_prefix("data: ")
                                    .or_else(|| line.strip_prefix("data:"))?;
                                if line == "[DONE]" || line.is_empty() {
                                    None
                                } else {
                                    match serde_json::from_str::<StreamEvent>(line) {
                                        Ok(event) => Some(Ok(event)),
                                        Err(error) => {
                                            // Log the raw payload so parse failures are debuggable.
                                            log::error!(
                                                "Failed to parse OpenAI responses stream event: `{}`\nResponse: `{}`",
                                                error,
                                                line,
                                            );
                                            Some(Err(anyhow!(error)))
                                        }
                                    }
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    })
                    .boxed())
            } else {
                // Non-streaming: read the whole body, then synthesize the
                // event sequence a streaming response would have produced.
                let mut body = String::new();
                response
                    .body_mut()
                    .read_to_string(&mut body)
                    .await
                    .map_err(|e| RequestError::Other(e.into()))?;

                match serde_json::from_str::<ResponseSummary>(&body) {
                    Ok(response_summary) => {
                        let events = vec![
                            StreamEvent::Created {
                                response: response_summary.clone(),
                            },
                            StreamEvent::InProgress {
                                response: response_summary.clone(),
                            },
                        ];

                        let mut all_events = events;
                        for (output_index, item) in response_summary.output.iter().enumerate() {
                            all_events.push(StreamEvent::OutputItemAdded {
                                output_index,
                                sequence_number: None,
                                item: item.clone(),
                            });

                            match item {
                                ResponseOutputItem::Message(message) => {
                                    // Emit each text content part as a single
                                    // whole-text delta; parts without an item
                                    // id or text are skipped.
                                    for content_item in &message.content {
                                        if let Some(text) = content_item.get("text") {
                                            if let Some(text_str) = text.as_str() {
                                                if let Some(ref item_id) = message.id {
                                                    all_events.push(StreamEvent::OutputTextDelta {
                                                        item_id: item_id.clone(),
                                                        output_index,
                                                        content_index: None,
                                                        delta: text_str.to_string(),
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                                ResponseOutputItem::FunctionCall(function_call) => {
                                    // Arguments are already complete, so emit
                                    // only the terminal "done" event.
                                    if let Some(ref item_id) = function_call.id {
                                        all_events.push(StreamEvent::FunctionCallArgumentsDone {
                                            item_id: item_id.clone(),
                                            output_index,
                                            arguments: function_call.arguments.clone(),
                                            sequence_number: None,
                                        });
                                    }
                                }
                                ResponseOutputItem::Unknown => {}
                            }

                            all_events.push(StreamEvent::OutputItemDone {
                                output_index,
                                sequence_number: None,
                                item: item.clone(),
                            });
                        }

                        all_events.push(StreamEvent::Completed {
                            response: response_summary,
                        });

                        Ok(futures::stream::iter(all_events.into_iter().map(Ok)).boxed())
                    }
                    Err(error) => {
                        log::error!(
                            "Failed to parse OpenAI non-streaming response: `{}`\nResponse: `{}`",
                            error,
                            body,
                        );
                        Err(RequestError::Other(anyhow!(error)))
                    }
                }
            }
        } else {
            // Error path: read the full body so the caller gets the API's message.
            let mut body = String::new();
            response
                .body_mut()
                .read_to_string(&mut body)
                .await
                .map_err(|e| RequestError::Other(e.into()))?;

            Err(RequestError::HttpResponseError {
                provider: provider_name.to_owned(),
                status_code: response.status(),
                body,
                headers: response.headers().clone(),
            })
        }
    }
1010}