1pub mod responses;
2
3use anyhow::{Context as _, Result, anyhow};
4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
5use http_client::{
6 AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
7 http::{HeaderMap, HeaderValue},
8};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11pub use settings::OpenAiReasoningEffort as ReasoningEffort;
12use std::{convert::TryFrom, future::Future};
13use strum::EnumIter;
14use thiserror::Error;
15
/// Default base URL for the OpenAI API.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
17
/// Serde `skip_serializing_if` helper: `true` when `opt` is `None` or holds an
/// empty slice-like value.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
21
/// The author of a chat message, serialized as the lowercase `role` field of
/// the chat completions API ("user", "assistant", "system", "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
30
31impl TryFrom<String> for Role {
32 type Error = anyhow::Error;
33
34 fn try_from(value: String) -> Result<Self> {
35 match value.as_str() {
36 "user" => Ok(Self::User),
37 "assistant" => Ok(Self::Assistant),
38 "system" => Ok(Self::System),
39 "tool" => Ok(Self::Tool),
40 _ => anyhow::bail!("invalid role '{value}'"),
41 }
42 }
43}
44
45impl From<Role> for String {
46 fn from(val: Role) -> Self {
47 match val {
48 Role::User => "user".to_owned(),
49 Role::Assistant => "assistant".to_owned(),
50 Role::System => "system".to_owned(),
51 Role::Tool => "tool".to_owned(),
52 }
53 }
54}
55
/// The set of OpenAI models known to this crate, plus a [`Model::Custom`]
/// escape hatch for arbitrary user-configured model ids.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    /// A user-configured model not in the list above.
    #[serde(rename = "custom")]
    Custom {
        /// The id sent to the API as the `model` parameter.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context window size in tokens.
        max_tokens: u64,
        /// Maximum output tokens, when known.
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        /// Reasoning effort to request for this model, if any.
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether the model can be used via the chat completions endpoint;
        /// defaults to `true` when absent from the serialized config.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
109
/// Serde default for `Model::Custom::supports_chat_completions`: custom models
/// are assumed to support the chat completions endpoint unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
113
114impl Model {
115 pub fn default_fast() -> Self {
116 // TODO: Replace with FiveMini since all other models are deprecated
117 Self::FourPointOneMini
118 }
119
120 pub fn from_id(id: &str) -> Result<Self> {
121 match id {
122 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
123 "gpt-4" => Ok(Self::Four),
124 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
125 "gpt-4o" => Ok(Self::FourOmni),
126 "gpt-4o-mini" => Ok(Self::FourOmniMini),
127 "gpt-4.1" => Ok(Self::FourPointOne),
128 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
129 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
130 "o1" => Ok(Self::O1),
131 "o3-mini" => Ok(Self::O3Mini),
132 "o3" => Ok(Self::O3),
133 "o4-mini" => Ok(Self::O4Mini),
134 "gpt-5" => Ok(Self::Five),
135 "gpt-5-codex" => Ok(Self::FiveCodex),
136 "gpt-5-mini" => Ok(Self::FiveMini),
137 "gpt-5-nano" => Ok(Self::FiveNano),
138 "gpt-5.1" => Ok(Self::FivePointOne),
139 "gpt-5.2" => Ok(Self::FivePointTwo),
140 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
141 }
142 }
143
144 pub fn id(&self) -> &str {
145 match self {
146 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147 Self::Four => "gpt-4",
148 Self::FourTurbo => "gpt-4-turbo",
149 Self::FourOmni => "gpt-4o",
150 Self::FourOmniMini => "gpt-4o-mini",
151 Self::FourPointOne => "gpt-4.1",
152 Self::FourPointOneMini => "gpt-4.1-mini",
153 Self::FourPointOneNano => "gpt-4.1-nano",
154 Self::O1 => "o1",
155 Self::O3Mini => "o3-mini",
156 Self::O3 => "o3",
157 Self::O4Mini => "o4-mini",
158 Self::Five => "gpt-5",
159 Self::FiveCodex => "gpt-5-codex",
160 Self::FiveMini => "gpt-5-mini",
161 Self::FiveNano => "gpt-5-nano",
162 Self::FivePointOne => "gpt-5.1",
163 Self::FivePointTwo => "gpt-5.2",
164 Self::Custom { name, .. } => name,
165 }
166 }
167
168 pub fn display_name(&self) -> &str {
169 match self {
170 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
171 Self::Four => "gpt-4",
172 Self::FourTurbo => "gpt-4-turbo",
173 Self::FourOmni => "gpt-4o",
174 Self::FourOmniMini => "gpt-4o-mini",
175 Self::FourPointOne => "gpt-4.1",
176 Self::FourPointOneMini => "gpt-4.1-mini",
177 Self::FourPointOneNano => "gpt-4.1-nano",
178 Self::O1 => "o1",
179 Self::O3Mini => "o3-mini",
180 Self::O3 => "o3",
181 Self::O4Mini => "o4-mini",
182 Self::Five => "gpt-5",
183 Self::FiveCodex => "gpt-5-codex",
184 Self::FiveMini => "gpt-5-mini",
185 Self::FiveNano => "gpt-5-nano",
186 Self::FivePointOne => "gpt-5.1",
187 Self::FivePointTwo => "gpt-5.2",
188 Self::Custom {
189 name, display_name, ..
190 } => display_name.as_ref().unwrap_or(name),
191 }
192 }
193
194 pub fn max_token_count(&self) -> u64 {
195 match self {
196 Self::ThreePointFiveTurbo => 16_385,
197 Self::Four => 8_192,
198 Self::FourTurbo => 128_000,
199 Self::FourOmni => 128_000,
200 Self::FourOmniMini => 128_000,
201 Self::FourPointOne => 1_047_576,
202 Self::FourPointOneMini => 1_047_576,
203 Self::FourPointOneNano => 1_047_576,
204 Self::O1 => 200_000,
205 Self::O3Mini => 200_000,
206 Self::O3 => 200_000,
207 Self::O4Mini => 200_000,
208 Self::Five => 272_000,
209 Self::FiveCodex => 272_000,
210 Self::FiveMini => 272_000,
211 Self::FiveNano => 272_000,
212 Self::FivePointOne => 400_000,
213 Self::FivePointTwo => 400_000,
214 Self::Custom { max_tokens, .. } => *max_tokens,
215 }
216 }
217
218 pub fn max_output_tokens(&self) -> Option<u64> {
219 match self {
220 Self::Custom {
221 max_output_tokens, ..
222 } => *max_output_tokens,
223 Self::ThreePointFiveTurbo => Some(4_096),
224 Self::Four => Some(8_192),
225 Self::FourTurbo => Some(4_096),
226 Self::FourOmni => Some(16_384),
227 Self::FourOmniMini => Some(16_384),
228 Self::FourPointOne => Some(32_768),
229 Self::FourPointOneMini => Some(32_768),
230 Self::FourPointOneNano => Some(32_768),
231 Self::O1 => Some(100_000),
232 Self::O3Mini => Some(100_000),
233 Self::O3 => Some(100_000),
234 Self::O4Mini => Some(100_000),
235 Self::Five => Some(128_000),
236 Self::FiveCodex => Some(128_000),
237 Self::FiveMini => Some(128_000),
238 Self::FiveNano => Some(128_000),
239 Self::FivePointOne => Some(128_000),
240 Self::FivePointTwo => Some(128_000),
241 }
242 }
243
244 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
245 match self {
246 Self::Custom {
247 reasoning_effort, ..
248 } => reasoning_effort.to_owned(),
249 _ => None,
250 }
251 }
252
253 pub fn supports_chat_completions(&self) -> bool {
254 match self {
255 Self::Custom {
256 supports_chat_completions,
257 ..
258 } => *supports_chat_completions,
259 Self::FiveCodex => false,
260 _ => true,
261 }
262 }
263
264 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
265 ///
266 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
267 pub fn supports_parallel_tool_calls(&self) -> bool {
268 match self {
269 Self::ThreePointFiveTurbo
270 | Self::Four
271 | Self::FourTurbo
272 | Self::FourOmni
273 | Self::FourOmniMini
274 | Self::FourPointOne
275 | Self::FourPointOneMini
276 | Self::FourPointOneNano
277 | Self::Five
278 | Self::FiveCodex
279 | Self::FiveMini
280 | Self::FivePointOne
281 | Self::FivePointTwo
282 | Self::FiveNano => true,
283 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
284 }
285 }
286
287 /// Returns whether the given model supports the `prompt_cache_key` parameter.
288 ///
289 /// If the model does not support the parameter, do not pass it up.
290 pub fn supports_prompt_cache_key(&self) -> bool {
291 true
292 }
293}
294
/// Request body for the chat completions endpoint.
///
/// Optional fields are skipped entirely when unset rather than serialized as
/// `null` (see the `skip_serializing_if` attributes).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier, e.g. `"gpt-4o"`.
    pub model: String,
    /// Conversation history, oldest first.
    pub messages: Vec<RequestMessage>,
    /// When `true`, the response arrives as a stream of server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Sequences at which the model stops generating further tokens.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call during this request.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Cache key used to improve prompt-cache hit rates.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
318
/// The `tool_choice` request parameter: one of the standard modes
/// (`"auto"`, `"required"`, `"none"`) or a specific tool to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Forces a particular tool; serialized untagged, i.e. as the tool
    /// definition object itself rather than a string.
    #[serde(untagged)]
    Other(ToolDefinition),
}
328
/// A tool made available to the model. Currently only function tools exist,
/// serialized as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
335
/// Schema for a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters.
    pub parameters: Option<Value>,
}
342
/// A single chat message, tagged by its `role` field in the serialized JSON.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the turn consists only of tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// Result of a tool invocation, correlated to the originating call
    /// via `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
362
/// Message content: either a bare string or a list of typed parts (text and
/// images). Untagged so both wire shapes deserialize transparently.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
369
370impl MessageContent {
371 pub fn empty() -> Self {
372 MessageContent::Multipart(vec![])
373 }
374
375 pub fn push_part(&mut self, part: MessagePart) {
376 match self {
377 MessageContent::Plain(text) => {
378 *self =
379 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
380 }
381 MessageContent::Multipart(parts) if parts.is_empty() => match part {
382 MessagePart::Text { text } => *self = MessageContent::Plain(text),
383 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
384 },
385 MessageContent::Multipart(parts) => parts.push(part),
386 }
387 }
388}
389
390impl From<Vec<MessagePart>> for MessageContent {
391 fn from(mut parts: Vec<MessagePart>) -> Self {
392 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
393 MessageContent::Plain(std::mem::take(text))
394 } else {
395 MessageContent::Multipart(parts)
396 }
397 }
398}
399
/// One element of a multipart message, tagged by its `type` field.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
408
/// Image reference inside a multipart message.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// URL of the image (or a data URL with inline image data).
    pub url: String,
    /// Optional detail level hint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
415
/// A tool invocation issued by the assistant.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier used to correlate the eventual `Tool` reply message.
    pub id: String,
    /// Flattened into the same JSON object as `id` (adds `type` + payload).
    #[serde(flatten)]
    pub content: ToolCallContent,
}
422
/// Payload of a tool call, tagged by `type`; only function calls exist today.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
428
/// The function a tool call targets, with its arguments.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a raw JSON string, exactly as produced by the model.
    pub arguments: String,
}
434
/// Non-streaming chat completions response body.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Creation time as a Unix timestamp.
    pub created: u64,
    /// The model that actually served the request.
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
444
/// One completion alternative in a non-streaming [`Response`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    /// Why generation stopped (e.g. natural stop vs. length), when reported.
    pub finish_reason: Option<String>,
}
451
/// Incremental message update carried by one streaming chunk; every field
/// may be absent because chunks deliver the message piecemeal.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
459
/// A fragment of a tool call streamed across chunks; fragments sharing the
/// same `index` belong to the same call and must be concatenated by the caller.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of the tool call this fragment belongs to.
    pub index: usize,
    /// Call id; typically only present on the first fragment of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
470
/// Partial function-call data within a [`ToolCallChunk`]; `arguments`
/// accumulates across fragments into a complete JSON string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
476
/// Token accounting reported by the API for a completed request.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
483
/// One choice entry within a streaming chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    /// Incremental update for this choice; absent in some terminal chunks.
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
490
/// Error returned by [`stream_completion`].
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server answered with a non-success HTTP status; the response body
    /// and headers are preserved for diagnostics (e.g. rate-limit headers).
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, transport, I/O).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
503
/// Error object embedded in a streaming event payload.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
508
/// A single decoded SSE payload: either a normal event or an in-band error.
/// Untagged, so deserialization tries the event shape first.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
515
/// One streaming chunk of a chat completions response.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    /// Present only on the final chunk (when usage reporting is enabled).
    pub usage: Option<Usage>,
}
521
522pub async fn stream_completion(
523 client: &dyn HttpClient,
524 provider_name: &str,
525 api_url: &str,
526 api_key: &str,
527 request: Request,
528) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
529 let uri = format!("{api_url}/chat/completions");
530 let request_builder = HttpRequest::builder()
531 .method(Method::POST)
532 .uri(uri)
533 .header("Content-Type", "application/json")
534 .header("Authorization", format!("Bearer {}", api_key.trim()));
535
536 let request = request_builder
537 .body(AsyncBody::from(
538 serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
539 ))
540 .map_err(|e| RequestError::Other(e.into()))?;
541
542 let mut response = client.send(request).await?;
543 if response.status().is_success() {
544 let reader = BufReader::new(response.into_body());
545 Ok(reader
546 .lines()
547 .filter_map(|line| async move {
548 match line {
549 Ok(line) => {
550 let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
551 if line == "[DONE]" {
552 None
553 } else {
554 match serde_json::from_str(line) {
555 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
556 Ok(ResponseStreamResult::Err { error }) => {
557 Some(Err(anyhow!(error.message)))
558 }
559 Err(error) => {
560 log::error!(
561 "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
562 Response: `{}`",
563 error,
564 line,
565 );
566 Some(Err(anyhow!(error)))
567 }
568 }
569 }
570 }
571 Err(error) => Some(Err(anyhow!(error))),
572 }
573 })
574 .boxed())
575 } else {
576 let mut body = String::new();
577 response
578 .body_mut()
579 .read_to_string(&mut body)
580 .await
581 .map_err(|e| RequestError::Other(e.into()))?;
582
583 Err(RequestError::HttpResponseError {
584 provider: provider_name.to_owned(),
585 status_code: response.status(),
586 body,
587 headers: response.headers().clone(),
588 })
589 }
590}
591
/// Embedding models accepted by the embeddings endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
599
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
605
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
610
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
615
616pub fn embed<'a>(
617 client: &dyn HttpClient,
618 api_url: &str,
619 api_key: &str,
620 model: OpenAiEmbeddingModel,
621 texts: impl IntoIterator<Item = &'a str>,
622) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
623 let uri = format!("{api_url}/embeddings");
624
625 let request = OpenAiEmbeddingRequest {
626 model,
627 input: texts.into_iter().collect(),
628 };
629 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
630 let request = HttpRequest::builder()
631 .method(Method::POST)
632 .uri(uri)
633 .header("Content-Type", "application/json")
634 .header("Authorization", format!("Bearer {}", api_key.trim()))
635 .body(body)
636 .map(|request| client.send(request));
637
638 async move {
639 let mut response = request?.await?;
640 let mut body = String::new();
641 response.body_mut().read_to_string(&mut body).await?;
642
643 anyhow::ensure!(
644 response.status().is_success(),
645 "error during embedding, status: {:?}, body: {:?}",
646 response.status(),
647 body
648 );
649 let response: OpenAiEmbeddingResponse =
650 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
651 Ok(response)
652 }
653}