1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, future::Future};
7use strum::EnumIter;
8
/// Default base URL for the hosted OpenAI API; callers may pass a different
/// `api_url` to [`stream_completion`] / [`embed`] for compatible endpoints.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
10
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate so empty optional
/// collections are omitted from the serialized payload.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
14
/// The author of a chat message, serialized in lowercase (e.g. `"user"`)
/// to match the OpenAI chat wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
23
24impl TryFrom<String> for Role {
25 type Error = anyhow::Error;
26
27 fn try_from(value: String) -> Result<Self> {
28 match value.as_str() {
29 "user" => Ok(Self::User),
30 "assistant" => Ok(Self::Assistant),
31 "system" => Ok(Self::System),
32 "tool" => Ok(Self::Tool),
33 _ => anyhow::bail!("invalid role '{value}'"),
34 }
35 }
36}
37
38impl From<Role> for String {
39 fn from(val: Role) -> Self {
40 match val {
41 Role::User => "user".to_owned(),
42 Role::Assistant => "assistant".to_owned(),
43 Role::System => "system".to_owned(),
44 Role::Tool => "tool".to_owned(),
45 }
46 }
47}
48
/// An OpenAI model selectable for completions.
///
/// Built-in variants carry hard-coded token limits (see
/// [`Model::max_token_count`] / [`Model::max_output_tokens`]); `Custom`
/// lets callers describe a model this enum does not know about.
/// Variant order is significant: it drives `EnumIter` iteration order.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,

    #[serde(rename = "custom")]
    Custom {
        // The API identifier for the model (returned by `Model::id`).
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Maximum context size in tokens (returned by `Model::max_token_count`).
        max_tokens: u64,
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        reasoning_effort: Option<ReasoningEffort>,
    },
}
95
96impl Model {
97 pub fn default_fast() -> Self {
98 // TODO: Replace with FiveMini since all other models are deprecated
99 Self::FourPointOneMini
100 }
101
102 pub fn from_id(id: &str) -> Result<Self> {
103 match id {
104 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
105 "gpt-4" => Ok(Self::Four),
106 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
107 "gpt-4o" => Ok(Self::FourOmni),
108 "gpt-4o-mini" => Ok(Self::FourOmniMini),
109 "gpt-4.1" => Ok(Self::FourPointOne),
110 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
111 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
112 "o1" => Ok(Self::O1),
113 "o3-mini" => Ok(Self::O3Mini),
114 "o3" => Ok(Self::O3),
115 "o4-mini" => Ok(Self::O4Mini),
116 "gpt-5" => Ok(Self::Five),
117 "gpt-5-mini" => Ok(Self::FiveMini),
118 "gpt-5-nano" => Ok(Self::FiveNano),
119 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 match self {
125 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
126 Self::Four => "gpt-4",
127 Self::FourTurbo => "gpt-4-turbo",
128 Self::FourOmni => "gpt-4o",
129 Self::FourOmniMini => "gpt-4o-mini",
130 Self::FourPointOne => "gpt-4.1",
131 Self::FourPointOneMini => "gpt-4.1-mini",
132 Self::FourPointOneNano => "gpt-4.1-nano",
133 Self::O1 => "o1",
134 Self::O3Mini => "o3-mini",
135 Self::O3 => "o3",
136 Self::O4Mini => "o4-mini",
137 Self::Five => "gpt-5",
138 Self::FiveMini => "gpt-5-mini",
139 Self::FiveNano => "gpt-5-nano",
140 Self::Custom { name, .. } => name,
141 }
142 }
143
144 pub fn display_name(&self) -> &str {
145 match self {
146 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147 Self::Four => "gpt-4",
148 Self::FourTurbo => "gpt-4-turbo",
149 Self::FourOmni => "gpt-4o",
150 Self::FourOmniMini => "gpt-4o-mini",
151 Self::FourPointOne => "gpt-4.1",
152 Self::FourPointOneMini => "gpt-4.1-mini",
153 Self::FourPointOneNano => "gpt-4.1-nano",
154 Self::O1 => "o1",
155 Self::O3Mini => "o3-mini",
156 Self::O3 => "o3",
157 Self::O4Mini => "o4-mini",
158 Self::Five => "gpt-5",
159 Self::FiveMini => "gpt-5-mini",
160 Self::FiveNano => "gpt-5-nano",
161 Self::Custom {
162 name, display_name, ..
163 } => display_name.as_ref().unwrap_or(name),
164 }
165 }
166
167 pub fn max_token_count(&self) -> u64 {
168 match self {
169 Self::ThreePointFiveTurbo => 16_385,
170 Self::Four => 8_192,
171 Self::FourTurbo => 128_000,
172 Self::FourOmni => 128_000,
173 Self::FourOmniMini => 128_000,
174 Self::FourPointOne => 1_047_576,
175 Self::FourPointOneMini => 1_047_576,
176 Self::FourPointOneNano => 1_047_576,
177 Self::O1 => 200_000,
178 Self::O3Mini => 200_000,
179 Self::O3 => 200_000,
180 Self::O4Mini => 200_000,
181 Self::Five => 272_000,
182 Self::FiveMini => 272_000,
183 Self::FiveNano => 272_000,
184 Self::Custom { max_tokens, .. } => *max_tokens,
185 }
186 }
187
188 pub fn max_output_tokens(&self) -> Option<u64> {
189 match self {
190 Self::Custom {
191 max_output_tokens, ..
192 } => *max_output_tokens,
193 Self::ThreePointFiveTurbo => Some(4_096),
194 Self::Four => Some(8_192),
195 Self::FourTurbo => Some(4_096),
196 Self::FourOmni => Some(16_384),
197 Self::FourOmniMini => Some(16_384),
198 Self::FourPointOne => Some(32_768),
199 Self::FourPointOneMini => Some(32_768),
200 Self::FourPointOneNano => Some(32_768),
201 Self::O1 => Some(100_000),
202 Self::O3Mini => Some(100_000),
203 Self::O3 => Some(100_000),
204 Self::O4Mini => Some(100_000),
205 Self::Five => Some(128_000),
206 Self::FiveMini => Some(128_000),
207 Self::FiveNano => Some(128_000),
208 }
209 }
210
211 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
212 match self {
213 Self::Custom {
214 reasoning_effort, ..
215 } => reasoning_effort.to_owned(),
216 _ => None,
217 }
218 }
219
220 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
221 ///
222 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
223 pub fn supports_parallel_tool_calls(&self) -> bool {
224 match self {
225 Self::ThreePointFiveTurbo
226 | Self::Four
227 | Self::FourTurbo
228 | Self::FourOmni
229 | Self::FourOmniMini
230 | Self::FourPointOne
231 | Self::FourPointOneMini
232 | Self::FourPointOneNano
233 | Self::Five
234 | Self::FiveMini
235 | Self::FiveNano => true,
236 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
237 }
238 }
239
240 /// Returns whether the given model supports the `prompt_cache_key` parameter.
241 ///
242 /// If the model does not support the parameter, do not pass it up.
243 pub fn supports_prompt_cache_key(&self) -> bool {
244 true
245 }
246}
247
/// Request body for the chat completions endpoint
/// (mirrors the fields sent to `{api_url}/chat/completions`).
///
/// Optional fields are skipped during serialization when unset, so models
/// that reject a parameter (see [`Model::supports_parallel_tool_calls`])
/// never receive it.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model identifier, e.g. from `Model::id()`.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true the server streams the response; see `stream_completion`.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    // Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
270
/// The `tool_choice` request parameter: one of the lowercase keywords
/// (`"auto"`, `"required"`, `"none"`), or a specific tool definition.
///
/// `Other` is `#[serde(untagged)]`, so it must stay the last variant:
/// the keyword variants are matched first during deserialization.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
280
/// The `reasoning_effort` request parameter, serialized in lowercase
/// (`"minimal"`, `"low"`, `"medium"`, `"high"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
}
290
/// A tool the model may call, tagged by its `type` field in snake_case
/// (currently only `"function"`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
297
/// Description of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // Arbitrary JSON describing the function's parameters — presumably a
    // JSON schema object, as in the OpenAI function-calling API.
    pub parameters: Option<Value>,
}
304
/// A single chat message, tagged by its lowercase `role` field
/// (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Assistant content may be absent when the message only carries
        // tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        // Id of the tool call this message is the result of
        // (matches `ToolCall::id`).
        tool_call_id: String,
    },
}
324
/// Message content: either a plain string or a list of typed parts.
///
/// `#[serde(untagged)]`: a bare JSON string maps to `Plain`, an array of
/// parts maps to `Multipart`. Variant order matters for deserialization.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
331
332impl MessageContent {
333 pub fn empty() -> Self {
334 MessageContent::Multipart(vec![])
335 }
336
337 pub fn push_part(&mut self, part: MessagePart) {
338 match self {
339 MessageContent::Plain(text) => {
340 *self =
341 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
342 }
343 MessageContent::Multipart(parts) if parts.is_empty() => match part {
344 MessagePart::Text { text } => *self = MessageContent::Plain(text),
345 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
346 },
347 MessageContent::Multipart(parts) => parts.push(part),
348 }
349 }
350}
351
352impl From<Vec<MessagePart>> for MessageContent {
353 fn from(mut parts: Vec<MessagePart>) -> Self {
354 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
355 MessageContent::Plain(std::mem::take(text))
356 } else {
357 MessageContent::Multipart(parts)
358 }
359 }
360}
361
/// One element of a multipart message, tagged by its `type` field
/// (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
370
/// An image reference inside a multipart message.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    // URL (or, per the OpenAI API, possibly a data URI) of the image.
    pub url: String,
    // Optional detail level hint; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
377
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    // Unique id; echoed back in `RequestMessage::Tool::tool_call_id`.
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` in the JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
384
/// Payload of a tool call, tagged by its lowercase `type` field
/// (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
390
/// A concrete function invocation: the function name plus its arguments
/// as a raw JSON-encoded string (not parsed here).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
396
/// Incremental message delta from a streamed completion; every field may
/// be absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    // Skipped when `None` or an empty list (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
404
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    // Only present on the first fragment of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
415
/// A streamed fragment of a function call: the name and/or a piece of the
/// JSON `arguments` string, either of which may be absent in a chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
421
/// Token accounting reported by the API for a completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
428
/// One choice within a streamed response event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // e.g. "stop" / "tool_calls"; `None` while the choice is still streaming.
    pub finish_reason: Option<String>,
}
435
/// Error object embedded in OpenAI API error responses; only the message
/// text is retained.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAiError {
    message: String,
}
440
/// A single SSE payload: either a normal event or an error object.
///
/// `#[serde(untagged)]`: `Ok` is tried first, so variant order matters.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: OpenAiError },
}
447
/// One event of a streamed chat completion, as yielded by
/// [`stream_completion`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    // Usually only present on the final event, if at all.
    pub usage: Option<Usage>,
}
453
/// Starts a streaming chat completion against `{api_url}/chat/completions`.
///
/// On success, returns a stream of parsed server-sent events; malformed or
/// error events surface as `Err` items within the stream rather than ending
/// it. On a non-success HTTP status, the whole call fails.
///
/// # Errors
/// Fails if the request cannot be serialized, built, or sent, or if the
/// server responds with a non-success status (the server's error message is
/// included when the body parses as an OpenAI error object; otherwise the
/// raw status and body are reported).
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Trim the key to tolerate stray whitespace copied from config files.
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines may be "data: ..." or "data:...";
                        // any other line (keep-alives, blanks) is dropped here.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel carries no payload; skip it.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                // In-band error object from the server.
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                // Unparseable payload: log it, then yield the error.
                                Err(error) => {
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    // I/O error while reading the body.
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Shape of an OpenAI error response body: `{"error": {...}}`.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            // Prefer the server-provided message when one is present.
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "API request to {} failed: {}",
                api_url,
                response.error.message,
            )),

            // Fall back to the raw status and body.
            _ => anyhow::bail!(
                "API request to {} failed with status {}: {}",
                api_url,
                response.status(),
                body,
            ),
        }
    }
}
526
/// Embedding models accepted by [`embed`], serialized as their API ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
534
/// Request body for the embeddings endpoint; borrows the input texts to
/// avoid copying them into the request struct.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
540
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
545
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
550
/// Requests embeddings for `texts` from `{api_url}/embeddings`.
///
/// Note: this is a sync function returning a future, not an `async fn`. The
/// HTTP request is serialized, built, and handed to `client.send` eagerly,
/// so the returned future captures no borrows and can be `'static`.
///
/// # Errors
/// The returned future fails if the request cannot be built or sent, if the
/// server responds with a non-success status (status and body are included
/// in the error), or if the response body cannot be parsed.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // NOTE(review): `unwrap` here assumes serialization of the request
    // struct cannot fail, which holds for these plain string/enum fields.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    // Build and dispatch eagerly; `map` defers only the send future, keeping
    // any builder error to be surfaced by `request?` inside the async block.
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}