1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6pub use settings::OpenAiReasoningEffort as ReasoningEffort;
7use std::{convert::TryFrom, future::Future};
8use strum::EnumIter;
9
/// Base URL of the hosted OpenAI REST API (`v1`); endpoint paths such as
/// `/chat/completions` and `/embeddings` are appended to it.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
11
/// Serde `skip_serializing_if` helper: treats both `None` and `Some` of an
/// empty slice-like value as "nothing to serialize".
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
15
/// Conversation role for a chat message, matching OpenAI's wire format
/// ("user" / "assistant" / "system" / "tool" via the lowercase rename).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
24
25impl TryFrom<String> for Role {
26 type Error = anyhow::Error;
27
28 fn try_from(value: String) -> Result<Self> {
29 match value.as_str() {
30 "user" => Ok(Self::User),
31 "assistant" => Ok(Self::Assistant),
32 "system" => Ok(Self::System),
33 "tool" => Ok(Self::Tool),
34 _ => anyhow::bail!("invalid role '{value}'"),
35 }
36 }
37}
38
39impl From<Role> for String {
40 fn from(val: Role) -> Self {
41 match val {
42 Role::User => "user".to_owned(),
43 Role::Assistant => "assistant".to_owned(),
44 Role::System => "system".to_owned(),
45 Role::Tool => "tool".to_owned(),
46 }
47 }
48}
49
/// The set of OpenAI models this client knows about, serialized under the
/// model's API identifier (e.g. `"gpt-4o"`). `Custom` covers any model not
/// listed here, carrying its metadata inline.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,

    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        /// The identifier sent to the API as the `model` field.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Total context window size, in tokens.
        max_tokens: u64,
        /// Optional cap on generated output tokens.
        max_output_tokens: Option<u64>,
        /// Optional `max_completion_tokens` request parameter.
        max_completion_tokens: Option<u64>,
        /// Optional reasoning-effort setting for reasoning-capable models.
        reasoning_effort: Option<ReasoningEffort>,
    },
}
96
97impl Model {
98 pub const fn default_fast() -> Self {
99 // TODO: Replace with FiveMini since all other models are deprecated
100 Self::FourPointOneMini
101 }
102
103 pub fn from_id(id: &str) -> Result<Self> {
104 match id {
105 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
106 "gpt-4" => Ok(Self::Four),
107 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
108 "gpt-4o" => Ok(Self::FourOmni),
109 "gpt-4o-mini" => Ok(Self::FourOmniMini),
110 "gpt-4.1" => Ok(Self::FourPointOne),
111 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
112 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
113 "o1" => Ok(Self::O1),
114 "o3-mini" => Ok(Self::O3Mini),
115 "o3" => Ok(Self::O3),
116 "o4-mini" => Ok(Self::O4Mini),
117 "gpt-5" => Ok(Self::Five),
118 "gpt-5-mini" => Ok(Self::FiveMini),
119 "gpt-5-nano" => Ok(Self::FiveNano),
120 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
121 }
122 }
123
124 pub fn id(&self) -> &str {
125 match self {
126 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
127 Self::Four => "gpt-4",
128 Self::FourTurbo => "gpt-4-turbo",
129 Self::FourOmni => "gpt-4o",
130 Self::FourOmniMini => "gpt-4o-mini",
131 Self::FourPointOne => "gpt-4.1",
132 Self::FourPointOneMini => "gpt-4.1-mini",
133 Self::FourPointOneNano => "gpt-4.1-nano",
134 Self::O1 => "o1",
135 Self::O3Mini => "o3-mini",
136 Self::O3 => "o3",
137 Self::O4Mini => "o4-mini",
138 Self::Five => "gpt-5",
139 Self::FiveMini => "gpt-5-mini",
140 Self::FiveNano => "gpt-5-nano",
141 Self::Custom { name, .. } => name,
142 }
143 }
144
145 pub fn display_name(&self) -> &str {
146 match self {
147 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
148 Self::Four => "gpt-4",
149 Self::FourTurbo => "gpt-4-turbo",
150 Self::FourOmni => "gpt-4o",
151 Self::FourOmniMini => "gpt-4o-mini",
152 Self::FourPointOne => "gpt-4.1",
153 Self::FourPointOneMini => "gpt-4.1-mini",
154 Self::FourPointOneNano => "gpt-4.1-nano",
155 Self::O1 => "o1",
156 Self::O3Mini => "o3-mini",
157 Self::O3 => "o3",
158 Self::O4Mini => "o4-mini",
159 Self::Five => "gpt-5",
160 Self::FiveMini => "gpt-5-mini",
161 Self::FiveNano => "gpt-5-nano",
162 Self::Custom {
163 name, display_name, ..
164 } => display_name.as_ref().unwrap_or(name),
165 }
166 }
167
168 pub const fn max_token_count(&self) -> u64 {
169 match self {
170 Self::ThreePointFiveTurbo => 16_385,
171 Self::Four => 8_192,
172 Self::FourTurbo => 128_000,
173 Self::FourOmni => 128_000,
174 Self::FourOmniMini => 128_000,
175 Self::FourPointOne => 1_047_576,
176 Self::FourPointOneMini => 1_047_576,
177 Self::FourPointOneNano => 1_047_576,
178 Self::O1 => 200_000,
179 Self::O3Mini => 200_000,
180 Self::O3 => 200_000,
181 Self::O4Mini => 200_000,
182 Self::Five => 272_000,
183 Self::FiveMini => 272_000,
184 Self::FiveNano => 272_000,
185 Self::Custom { max_tokens, .. } => *max_tokens,
186 }
187 }
188
189 pub const fn max_output_tokens(&self) -> Option<u64> {
190 match self {
191 Self::Custom {
192 max_output_tokens, ..
193 } => *max_output_tokens,
194 Self::ThreePointFiveTurbo => Some(4_096),
195 Self::Four => Some(8_192),
196 Self::FourTurbo => Some(4_096),
197 Self::FourOmni => Some(16_384),
198 Self::FourOmniMini => Some(16_384),
199 Self::FourPointOne => Some(32_768),
200 Self::FourPointOneMini => Some(32_768),
201 Self::FourPointOneNano => Some(32_768),
202 Self::O1 => Some(100_000),
203 Self::O3Mini => Some(100_000),
204 Self::O3 => Some(100_000),
205 Self::O4Mini => Some(100_000),
206 Self::Five => Some(128_000),
207 Self::FiveMini => Some(128_000),
208 Self::FiveNano => Some(128_000),
209 }
210 }
211
212 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
213 match self {
214 Self::Custom {
215 reasoning_effort, ..
216 } => reasoning_effort.to_owned(),
217 _ => None,
218 }
219 }
220
221 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
222 ///
223 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
224 pub const fn supports_parallel_tool_calls(&self) -> bool {
225 match self {
226 Self::ThreePointFiveTurbo
227 | Self::Four
228 | Self::FourTurbo
229 | Self::FourOmni
230 | Self::FourOmniMini
231 | Self::FourPointOne
232 | Self::FourPointOneMini
233 | Self::FourPointOneNano
234 | Self::Five
235 | Self::FiveMini
236 | Self::FiveNano => true,
237 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
238 }
239 }
240
241 /// Returns whether the given model supports the `prompt_cache_key` parameter.
242 ///
243 /// If the model does not support the parameter, do not pass it up.
244 pub const fn supports_prompt_cache_key(&self) -> bool {
245 true
246 }
247}
248
/// Request payload for the `/chat/completions` endpoint. Optional fields are
/// omitted from the serialized JSON when unset, since some models reject
/// parameters they do not support.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model identifier, as produced by `Model::id`.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true, the API responds with server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    // Only sent for models where `Model::supports_prompt_cache_key` is true.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
271
/// The `tool_choice` request parameter: one of the keyword modes
/// ("auto" / "required" / "none"), or — via the untagged variant — a
/// specific tool definition the model is forced to call.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
281
/// A tool the model may call. Serialized with a `"type"` tag; currently only
/// `"function"` tools exist in the API.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
288
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
295
/// A single chat message, tagged by its `"role"` field on the wire. Each
/// role carries a different payload shape: assistant messages may include
/// tool calls, and tool messages reference the call they answer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        // Id of the assistant `ToolCall` this message responds to.
        tool_call_id: String,
    },
}
315
/// Message content: either a plain string or a list of typed parts
/// (text and images). Untagged so both wire shapes deserialize directly.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
322
323impl MessageContent {
324 pub const fn empty() -> Self {
325 MessageContent::Multipart(vec![])
326 }
327
328 pub fn push_part(&mut self, part: MessagePart) {
329 match self {
330 MessageContent::Plain(text) => {
331 *self =
332 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
333 }
334 MessageContent::Multipart(parts) if parts.is_empty() => match part {
335 MessagePart::Text { text } => *self = MessageContent::Plain(text),
336 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
337 },
338 MessageContent::Multipart(parts) => parts.push(part),
339 }
340 }
341}
342
343impl From<Vec<MessagePart>> for MessageContent {
344 fn from(mut parts: Vec<MessagePart>) -> Self {
345 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
346 MessageContent::Plain(std::mem::take(text))
347 } else {
348 MessageContent::Multipart(parts)
349 }
350 }
351}
352
/// One part of a multipart message, tagged by `"type"` on the wire:
/// `"text"` or `"image_url"`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
361
/// Image reference inside a multipart message.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    // URL or data-URI of the image.
    pub url: String,
    // Optional detail level hint; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
368
/// A tool call issued by the assistant; the typed payload is flattened so
/// `id` sits alongside the call's `type`/`function` fields on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
375
/// Payload of a tool call, tagged by `"type"`; currently only function calls.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
381
/// A concrete function invocation: the function's name plus its arguments
/// as a raw JSON string (as the API delivers them).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
387
/// Incremental message delta from a streaming response; every field may be
/// absent because each chunk only carries what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
395
/// Streaming fragment of a tool call. Fragments for the same call share an
/// `index` and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    // Position of the tool call this chunk belongs to.
    pub index: usize,
    // Only present on the first chunk of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
406
/// Streaming fragment of a function call: partial name and/or a partial
/// slice of the JSON `arguments` string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
412
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
419
/// One choice's delta within a streaming event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // E.g. "stop" or "tool_calls" once the choice finishes; `None` mid-stream.
    pub finish_reason: Option<String>,
}
426
/// Error object embedded in OpenAI API responses; only the message is kept.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAiError {
    message: String,
}
431
/// A single streamed frame, which is either an event payload or an error
/// object; untagged so serde tries each shape in order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: OpenAiError },
}
438
/// One streaming chat-completion event: per-choice deltas plus, on the
/// final event (when requested), the usage totals.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
444
/// Sends a streaming request to the `/chat/completions` endpoint and returns
/// a stream of parsed [`ResponseStreamEvent`]s.
///
/// The response body is read line by line as server-sent events: each
/// `data:`-prefixed line is parsed as JSON, the `[DONE]` sentinel ends the
/// stream, and non-SSE lines are skipped. On a non-success HTTP status the
/// body is read and the API's error message (when parseable) is returned.
///
/// # Errors
/// Fails if the request cannot be serialized/built/sent, or if the server
/// responds with a non-success status. Per-event parse errors are yielded as
/// `Err` items inside the stream rather than failing the whole call.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines arrive as "data: {...}" (the space is
                        // optional); any other line is ignored via the `?`.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: stop without yielding.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the raw payload to aid debugging of
                                    // schema drift, then surface the error.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-success status: read the body and try to extract the API's
        // structured error message for a clearer error.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "API request to {} failed: {}",
                api_url,
                response.error.message,
            )),

            // Fall back to the raw body when no structured error is present.
            _ => anyhow::bail!(
                "API request to {} failed with status {}: {}",
                api_url,
                response.status(),
                body,
            ),
        }
    }
}
517
/// Embedding models supported by this client, serialized under their API ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
525
/// Request payload for the `/embeddings` endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
531
/// Response from the `/embeddings` endpoint: one entry per input text,
/// in input order.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
536
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
541
542pub fn embed<'a>(
543 client: &dyn HttpClient,
544 api_url: &str,
545 api_key: &str,
546 model: OpenAiEmbeddingModel,
547 texts: impl IntoIterator<Item = &'a str>,
548) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
549 let uri = format!("{api_url}/embeddings");
550
551 let request = OpenAiEmbeddingRequest {
552 model,
553 input: texts.into_iter().collect(),
554 };
555 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
556 let request = HttpRequest::builder()
557 .method(Method::POST)
558 .uri(uri)
559 .header("Content-Type", "application/json")
560 .header("Authorization", format!("Bearer {}", api_key.trim()))
561 .body(body)
562 .map(|request| client.send(request));
563
564 async move {
565 let mut response = request?.await?;
566 let mut body = String::new();
567 response.body_mut().read_to_string(&mut body).await?;
568
569 anyhow::ensure!(
570 response.status().is_success(),
571 "error during embedding, status: {:?}, body: {:?}",
572 response.status(),
573 body
574 );
575 let response: OpenAiEmbeddingResponse =
576 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
577 Ok(response)
578 }
579}