1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, future::Future};
7use strum::EnumIter;
8
/// Base URL of the hosted OpenAI REST API (`v1`).
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
10
/// Returns `true` when `opt` is `None` or holds a slice-like value with no
/// elements. Used as a `skip_serializing_if` predicate.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
14
/// Author role of a chat message, serialized in the API's lowercase wire
/// format (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
23
24impl TryFrom<String> for Role {
25 type Error = anyhow::Error;
26
27 fn try_from(value: String) -> Result<Self> {
28 match value.as_str() {
29 "user" => Ok(Self::User),
30 "assistant" => Ok(Self::Assistant),
31 "system" => Ok(Self::System),
32 "tool" => Ok(Self::Tool),
33 _ => anyhow::bail!("invalid role '{value}'"),
34 }
35 }
36}
37
38impl From<Role> for String {
39 fn from(val: Role) -> Self {
40 match val {
41 Role::User => "user".to_owned(),
42 Role::Assistant => "assistant".to_owned(),
43 Role::System => "system".to_owned(),
44 Role::Tool => "tool".to_owned(),
45 }
46 }
47}
48
/// An OpenAI chat model: either one of the known hosted models, or a
/// user-configured `Custom` entry that carries its own limits.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,

    /// A model not in the list above, described entirely by configuration.
    #[serde(rename = "custom")]
    Custom {
        // The id sent to the API as the `model` parameter.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Maximum context window size, in tokens.
        max_tokens: u64,
        // Maximum output tokens, if limited.
        max_output_tokens: Option<u64>,
        // Value for the `max_completion_tokens` request parameter, if any.
        max_completion_tokens: Option<u64>,
        // Reasoning effort to request, for reasoning-capable models.
        reasoning_effort: Option<ReasoningEffort>,
    },
}
95
96impl Model {
97 pub fn default_fast() -> Self {
98 // TODO: Replace with FiveMini since all other models are deprecated
99 Self::FourPointOneMini
100 }
101
102 pub fn from_id(id: &str) -> Result<Self> {
103 match id {
104 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
105 "gpt-4" => Ok(Self::Four),
106 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
107 "gpt-4o" => Ok(Self::FourOmni),
108 "gpt-4o-mini" => Ok(Self::FourOmniMini),
109 "gpt-4.1" => Ok(Self::FourPointOne),
110 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
111 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
112 "o1" => Ok(Self::O1),
113 "o3-mini" => Ok(Self::O3Mini),
114 "o3" => Ok(Self::O3),
115 "o4-mini" => Ok(Self::O4Mini),
116 "gpt-5" => Ok(Self::Five),
117 "gpt-5-mini" => Ok(Self::FiveMini),
118 "gpt-5-nano" => Ok(Self::FiveNano),
119 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 match self {
125 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
126 Self::Four => "gpt-4",
127 Self::FourTurbo => "gpt-4-turbo",
128 Self::FourOmni => "gpt-4o",
129 Self::FourOmniMini => "gpt-4o-mini",
130 Self::FourPointOne => "gpt-4.1",
131 Self::FourPointOneMini => "gpt-4.1-mini",
132 Self::FourPointOneNano => "gpt-4.1-nano",
133 Self::O1 => "o1",
134 Self::O3Mini => "o3-mini",
135 Self::O3 => "o3",
136 Self::O4Mini => "o4-mini",
137 Self::Five => "gpt-5",
138 Self::FiveMini => "gpt-5-mini",
139 Self::FiveNano => "gpt-5-nano",
140 Self::Custom { name, .. } => name,
141 }
142 }
143
144 pub fn display_name(&self) -> &str {
145 match self {
146 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147 Self::Four => "gpt-4",
148 Self::FourTurbo => "gpt-4-turbo",
149 Self::FourOmni => "gpt-4o",
150 Self::FourOmniMini => "gpt-4o-mini",
151 Self::FourPointOne => "gpt-4.1",
152 Self::FourPointOneMini => "gpt-4.1-mini",
153 Self::FourPointOneNano => "gpt-4.1-nano",
154 Self::O1 => "o1",
155 Self::O3Mini => "o3-mini",
156 Self::O3 => "o3",
157 Self::O4Mini => "o4-mini",
158 Self::Five => "gpt-5",
159 Self::FiveMini => "gpt-5-mini",
160 Self::FiveNano => "gpt-5-nano",
161 Self::Custom {
162 name, display_name, ..
163 } => display_name.as_ref().unwrap_or(name),
164 }
165 }
166
167 pub fn max_token_count(&self) -> u64 {
168 match self {
169 Self::ThreePointFiveTurbo => 16_385,
170 Self::Four => 8_192,
171 Self::FourTurbo => 128_000,
172 Self::FourOmni => 128_000,
173 Self::FourOmniMini => 128_000,
174 Self::FourPointOne => 1_047_576,
175 Self::FourPointOneMini => 1_047_576,
176 Self::FourPointOneNano => 1_047_576,
177 Self::O1 => 200_000,
178 Self::O3Mini => 200_000,
179 Self::O3 => 200_000,
180 Self::O4Mini => 200_000,
181 Self::Five => 272_000,
182 Self::FiveMini => 272_000,
183 Self::FiveNano => 272_000,
184 Self::Custom { max_tokens, .. } => *max_tokens,
185 }
186 }
187
188 pub fn max_output_tokens(&self) -> Option<u64> {
189 match self {
190 Self::Custom {
191 max_output_tokens, ..
192 } => *max_output_tokens,
193 Self::ThreePointFiveTurbo => Some(4_096),
194 Self::Four => Some(8_192),
195 Self::FourTurbo => Some(4_096),
196 Self::FourOmni => Some(16_384),
197 Self::FourOmniMini => Some(16_384),
198 Self::FourPointOne => Some(32_768),
199 Self::FourPointOneMini => Some(32_768),
200 Self::FourPointOneNano => Some(32_768),
201 Self::O1 => Some(100_000),
202 Self::O3Mini => Some(100_000),
203 Self::O3 => Some(100_000),
204 Self::O4Mini => Some(100_000),
205 Self::Five => Some(128_000),
206 Self::FiveMini => Some(128_000),
207 Self::FiveNano => Some(128_000),
208 }
209 }
210
211 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
212 match self {
213 Self::Custom {
214 reasoning_effort, ..
215 } => reasoning_effort.to_owned(),
216 _ => None,
217 }
218 }
219
220 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
221 ///
222 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
223 pub fn supports_parallel_tool_calls(&self) -> bool {
224 match self {
225 Self::ThreePointFiveTurbo
226 | Self::Four
227 | Self::FourTurbo
228 | Self::FourOmni
229 | Self::FourOmniMini
230 | Self::FourPointOne
231 | Self::FourPointOneMini
232 | Self::FourPointOneNano
233 | Self::Five
234 | Self::FiveMini
235 | Self::FiveNano => true,
236 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
237 }
238 }
239
240 /// Returns whether the given model supports the `prompt_cache_key` parameter.
241 ///
242 /// If the model does not support the parameter, do not pass it up.
243 pub fn supports_prompt_cache_key(&self) -> bool {
244 true
245 }
246}
247
/// Body of a chat-completions request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model id, as produced by `Model::id`.
    pub model: String,
    // Conversation history, oldest message first.
    pub messages: Vec<RequestMessage>,
    // Whether the response should be streamed as server-sent events.
    pub stream: bool,
    // Upper bound on tokens generated for the completion.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    // Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    // Sampling temperature.
    pub temperature: f32,
    // How the model may select tools; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    // Tool definitions available to the model; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    // Cache key hint for prompt caching; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    // Reasoning effort for reasoning-capable models; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
270
271#[derive(Debug, Serialize, Deserialize)]
272#[serde(untagged)]
273pub enum ToolChoice {
274 Auto,
275 Required,
276 None,
277 Other(ToolDefinition),
278}
279
/// Value of the `reasoning_effort` request parameter, serialized lowercase
/// (`"minimal"`, `"low"`, `"medium"`, `"high"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
}
289
/// A tool the model may call, tagged by `type` on the wire. Currently only
/// `function` tools are modeled.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
296
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    // Function name the model uses to invoke it.
    pub name: String,
    // Natural-language description of what the function does.
    pub description: Option<String>,
    // JSON schema for the function's parameters, if any.
    pub parameters: Option<Value>,
}
303
/// One chat message in a request, tagged by `role` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// An assistant turn; may carry text content and/or tool calls.
    Assistant {
        content: Option<MessageContent>,
        // Tool invocations requested by the assistant; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// A user turn.
    User {
        content: MessageContent,
    },
    /// System instructions.
    System {
        content: MessageContent,
    },
    /// The result of a tool call, matched to its request by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
323
/// Message content in either of the API's two accepted shapes: a plain
/// string, or a list of typed parts (text / image). Serialized untagged so
/// each variant maps directly onto its wire shape.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
330
331impl MessageContent {
332 pub fn empty() -> Self {
333 MessageContent::Multipart(vec![])
334 }
335
336 pub fn push_part(&mut self, part: MessagePart) {
337 match self {
338 MessageContent::Plain(text) => {
339 *self =
340 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
341 }
342 MessageContent::Multipart(parts) if parts.is_empty() => match part {
343 MessagePart::Text { text } => *self = MessageContent::Plain(text),
344 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
345 },
346 MessageContent::Multipart(parts) => parts.push(part),
347 }
348 }
349}
350
351impl From<Vec<MessagePart>> for MessageContent {
352 fn from(mut parts: Vec<MessagePart>) -> Self {
353 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
354 MessageContent::Plain(std::mem::take(text))
355 } else {
356 MessageContent::Multipart(parts)
357 }
358 }
359}
360
/// One element of multipart message content, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
369
/// Image reference used by an `image_url` message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    // URL (or data URI) of the image.
    pub url: String,
    // Optional detail-level hint; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
376
/// A complete tool call issued by the assistant. The content is flattened so
/// its `type`/`function` fields appear alongside `id` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    // Identifier used to match a tool result back to this call.
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
383
/// Payload of a tool call, tagged by `type`. Currently only function calls
/// are modeled.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
389
/// A function invocation: the function's name and its arguments as a raw
/// JSON string (not parsed here).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
395
/// Incremental message delta from a streamed response; every field may be
/// absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    // Omitted when absent or empty (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
403
/// A streamed fragment of a tool call; fragments sharing the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    // Position of the tool call this fragment belongs to.
    pub index: usize,
    // Present only on the first fragment of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
414
/// A streamed fragment of a function call: the name and/or a piece of the
/// arguments JSON string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
420
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
427
/// One choice within a streamed response chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    // Index of this choice among the returned choices.
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // Set on the final chunk of the choice (e.g. "stop", "tool_calls").
    pub finish_reason: Option<String>,
}
434
/// Error object embedded in API responses; only the message is retained.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAiError {
    message: String,
}
439
/// A single streamed payload: either a normal event or an `{"error": ...}`
/// object. Untagged so whichever shape parses wins.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: OpenAiError },
}
446
/// One chunk of a streamed chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    // Usually present only on the final chunk, when requested.
    pub usage: Option<Usage>,
}
453
454pub async fn stream_completion(
455 client: &dyn HttpClient,
456 api_url: &str,
457 api_key: &str,
458 request: Request,
459) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
460 let uri = format!("{api_url}/chat/completions");
461 let request_builder = HttpRequest::builder()
462 .method(Method::POST)
463 .uri(uri)
464 .header("Content-Type", "application/json")
465 .header("Authorization", format!("Bearer {}", api_key));
466
467 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
468 let mut response = client.send(request).await?;
469 if response.status().is_success() {
470 let reader = BufReader::new(response.into_body());
471 Ok(reader
472 .lines()
473 .filter_map(|line| async move {
474 match line {
475 Ok(line) => {
476 let line = line.strip_prefix("data: ")?;
477 if line == "[DONE]" {
478 None
479 } else {
480 match serde_json::from_str(line) {
481 Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
482 Ok(ResponseStreamResult::Err { error }) => {
483 Some(Err(anyhow!(error.message)))
484 }
485 Err(error) => {
486 log::error!(
487 "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
488 Response: `{}`",
489 error,
490 line,
491 );
492 Some(Err(anyhow!(error)))
493 }
494 }
495 }
496 }
497 Err(error) => Some(Err(anyhow!(error))),
498 }
499 })
500 .boxed())
501 } else {
502 let mut body = String::new();
503 response.body_mut().read_to_string(&mut body).await?;
504
505 #[derive(Deserialize)]
506 struct OpenAiResponse {
507 error: OpenAiError,
508 }
509
510 match serde_json::from_str::<OpenAiResponse>(&body) {
511 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
512 "API request to {} failed: {}",
513 api_url,
514 response.error.message,
515 )),
516
517 _ => anyhow::bail!(
518 "API request to {} failed with status {}: {}",
519 api_url,
520 response.status(),
521 body,
522 ),
523 }
524 }
525}
526
/// Embedding models accepted by the embeddings endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
534
/// Body of an embeddings request; borrows the input texts, so it is
/// serialized before any `'static` future is created.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
540
/// Response from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
545
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
550
551pub fn embed<'a>(
552 client: &dyn HttpClient,
553 api_url: &str,
554 api_key: &str,
555 model: OpenAiEmbeddingModel,
556 texts: impl IntoIterator<Item = &'a str>,
557) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
558 let uri = format!("{api_url}/embeddings");
559
560 let request = OpenAiEmbeddingRequest {
561 model,
562 input: texts.into_iter().collect(),
563 };
564 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
565 let request = HttpRequest::builder()
566 .method(Method::POST)
567 .uri(uri)
568 .header("Content-Type", "application/json")
569 .header("Authorization", format!("Bearer {}", api_key))
570 .body(body)
571 .map(|request| client.send(request));
572
573 async move {
574 let mut response = request?.await?;
575 let mut body = String::new();
576 response.body_mut().read_to_string(&mut body).await?;
577
578 anyhow::ensure!(
579 response.status().is_success(),
580 "error during embedding, status: {:?}, body: {:?}",
581 response.status(),
582 body
583 );
584 let response: OpenAiEmbeddingResponse =
585 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
586 Ok(response)
587 }
588}