1use anyhow::{Context as _, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, future::Future};
7use strum::EnumIter;
8
9pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
10
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate so empty optional collections
/// are omitted from serialized payloads.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(value) => value.as_ref().is_empty(),
        None => true,
    }
}
14
/// The author role of a chat message, serialized lowercase
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    /// A message carrying the result of a tool call.
    Tool,
}
23
24impl TryFrom<String> for Role {
25 type Error = anyhow::Error;
26
27 fn try_from(value: String) -> Result<Self> {
28 match value.as_str() {
29 "user" => Ok(Self::User),
30 "assistant" => Ok(Self::Assistant),
31 "system" => Ok(Self::System),
32 "tool" => Ok(Self::Tool),
33 _ => anyhow::bail!("invalid role '{value}'"),
34 }
35 }
36}
37
38impl From<Role> for String {
39 fn from(val: Role) -> Self {
40 match val {
41 Role::User => "user".to_owned(),
42 Role::Assistant => "assistant".to_owned(),
43 Role::System => "system".to_owned(),
44 Role::Tool => "tool".to_owned(),
45 }
46 }
47}
48
/// A chat-completion model.
///
/// Named variants cover known OpenAI model ids; `Custom` identifies any
/// other model (e.g. a compatible proxy) by name, with caller-supplied
/// limits.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,

    /// A model not in the list above, identified by `name`.
    #[serde(rename = "custom")]
    Custom {
        /// The model name sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Maximum token count reported by [`Model::max_token_count`].
        max_tokens: u64,
        /// Optional output-token cap reported by [`Model::max_output_tokens`].
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        /// Optional reasoning effort reported by [`Model::reasoning_effort`].
        reasoning_effort: Option<ReasoningEffort>,
    },
}
95
96impl Model {
97 pub fn default_fast() -> Self {
98 // TODO: Replace with FiveMini since all other models are deprecated
99 Self::FourPointOneMini
100 }
101
102 pub fn from_id(id: &str) -> Result<Self> {
103 match id {
104 "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
105 "gpt-4" => Ok(Self::Four),
106 "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
107 "gpt-4o" => Ok(Self::FourOmni),
108 "gpt-4o-mini" => Ok(Self::FourOmniMini),
109 "gpt-4.1" => Ok(Self::FourPointOne),
110 "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
111 "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
112 "o1" => Ok(Self::O1),
113 "o3-mini" => Ok(Self::O3Mini),
114 "o3" => Ok(Self::O3),
115 "o4-mini" => Ok(Self::O4Mini),
116 "gpt-5" => Ok(Self::Five),
117 "gpt-5-mini" => Ok(Self::FiveMini),
118 "gpt-5-nano" => Ok(Self::FiveNano),
119 invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 match self {
125 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
126 Self::Four => "gpt-4",
127 Self::FourTurbo => "gpt-4-turbo",
128 Self::FourOmni => "gpt-4o",
129 Self::FourOmniMini => "gpt-4o-mini",
130 Self::FourPointOne => "gpt-4.1",
131 Self::FourPointOneMini => "gpt-4.1-mini",
132 Self::FourPointOneNano => "gpt-4.1-nano",
133 Self::O1 => "o1",
134 Self::O3Mini => "o3-mini",
135 Self::O3 => "o3",
136 Self::O4Mini => "o4-mini",
137 Self::Five => "gpt-5",
138 Self::FiveMini => "gpt-5-mini",
139 Self::FiveNano => "gpt-5-nano",
140 Self::Custom { name, .. } => name,
141 }
142 }
143
144 pub fn display_name(&self) -> &str {
145 match self {
146 Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147 Self::Four => "gpt-4",
148 Self::FourTurbo => "gpt-4-turbo",
149 Self::FourOmni => "gpt-4o",
150 Self::FourOmniMini => "gpt-4o-mini",
151 Self::FourPointOne => "gpt-4.1",
152 Self::FourPointOneMini => "gpt-4.1-mini",
153 Self::FourPointOneNano => "gpt-4.1-nano",
154 Self::O1 => "o1",
155 Self::O3Mini => "o3-mini",
156 Self::O3 => "o3",
157 Self::O4Mini => "o4-mini",
158 Self::Five => "gpt-5",
159 Self::FiveMini => "gpt-5-mini",
160 Self::FiveNano => "gpt-5-nano",
161 Self::Custom {
162 name, display_name, ..
163 } => display_name.as_ref().unwrap_or(name),
164 }
165 }
166
167 pub fn max_token_count(&self) -> u64 {
168 match self {
169 Self::ThreePointFiveTurbo => 16_385,
170 Self::Four => 8_192,
171 Self::FourTurbo => 128_000,
172 Self::FourOmni => 128_000,
173 Self::FourOmniMini => 128_000,
174 Self::FourPointOne => 1_047_576,
175 Self::FourPointOneMini => 1_047_576,
176 Self::FourPointOneNano => 1_047_576,
177 Self::O1 => 200_000,
178 Self::O3Mini => 200_000,
179 Self::O3 => 200_000,
180 Self::O4Mini => 200_000,
181 Self::Five => 272_000,
182 Self::FiveMini => 272_000,
183 Self::FiveNano => 272_000,
184 Self::Custom { max_tokens, .. } => *max_tokens,
185 }
186 }
187
188 pub fn max_output_tokens(&self) -> Option<u64> {
189 match self {
190 Self::Custom {
191 max_output_tokens, ..
192 } => *max_output_tokens,
193 Self::ThreePointFiveTurbo => Some(4_096),
194 Self::Four => Some(8_192),
195 Self::FourTurbo => Some(4_096),
196 Self::FourOmni => Some(16_384),
197 Self::FourOmniMini => Some(16_384),
198 Self::FourPointOne => Some(32_768),
199 Self::FourPointOneMini => Some(32_768),
200 Self::FourPointOneNano => Some(32_768),
201 Self::O1 => Some(100_000),
202 Self::O3Mini => Some(100_000),
203 Self::O3 => Some(100_000),
204 Self::O4Mini => Some(100_000),
205 Self::Five => Some(128_000),
206 Self::FiveMini => Some(128_000),
207 Self::FiveNano => Some(128_000),
208 }
209 }
210
211 pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
212 match self {
213 Self::Custom {
214 reasoning_effort, ..
215 } => reasoning_effort.to_owned(),
216 _ => None,
217 }
218 }
219
220 /// Returns whether the given model supports the `parallel_tool_calls` parameter.
221 ///
222 /// If the model does not support the parameter, do not pass it up, or the API will return an error.
223 pub fn supports_parallel_tool_calls(&self) -> bool {
224 match self {
225 Self::ThreePointFiveTurbo
226 | Self::Four
227 | Self::FourTurbo
228 | Self::FourOmni
229 | Self::FourOmniMini
230 | Self::FourPointOne
231 | Self::FourPointOneMini
232 | Self::FourPointOneNano
233 | Self::Five
234 | Self::FiveMini
235 | Self::FiveNano => true,
236 Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
237 }
238 }
239}
240
/// Request body for the chat completions endpoint
/// (serialized to JSON by [`stream_completion`]).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// API model identifier (see [`Model::id`]).
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether the response should be streamed as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
263
264#[derive(Debug, Serialize, Deserialize)]
265#[serde(untagged)]
266pub enum ToolChoice {
267 Auto,
268 Required,
269 None,
270 Other(ToolDefinition),
271}
272
/// Reasoning effort level, serialized lowercase (e.g. `"minimal"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Minimal,
    Low,
    Medium,
    High,
}
282
/// A tool the model may call.
///
/// Serialized with a `"type"` tag; currently only function tools exist.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
289
/// Declaration of a callable function tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema for the function's parameters, if any.
    pub parameters: Option<Value>,
}
296
/// A single chat message in a [`Request`], tagged by its `"role"` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<MessageContent>,
        /// Tool calls the assistant made; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// The result of a tool call, keyed back to it by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
316
/// Message content: either a bare string or a list of typed parts
/// (serialized untagged, so the JSON is a string or an array).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
323
324impl MessageContent {
325 pub fn empty() -> Self {
326 MessageContent::Multipart(vec![])
327 }
328
329 pub fn push_part(&mut self, part: MessagePart) {
330 match self {
331 MessageContent::Plain(text) => {
332 *self =
333 MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
334 }
335 MessageContent::Multipart(parts) if parts.is_empty() => match part {
336 MessagePart::Text { text } => *self = MessageContent::Plain(text),
337 MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
338 },
339 MessageContent::Multipart(parts) => parts.push(part),
340 }
341 }
342}
343
344impl From<Vec<MessagePart>> for MessageContent {
345 fn from(mut parts: Vec<MessagePart>) -> Self {
346 if let [MessagePart::Text { text }] = parts.as_mut_slice() {
347 MessageContent::Plain(std::mem::take(text))
348 } else {
349 MessageContent::Multipart(parts)
350 }
351 }
352}
353
/// One part of a multipart message, tagged by its `"type"` field.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
362
/// An image reference inside a message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Optional detail level; omitted from the payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
369
/// A tool call issued by the assistant, identified by `id`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so the call's `type`/`function` fields sit next to `id`
    // in the serialized JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
376
/// The payload of a [`ToolCall`], tagged by `"type"`;
/// currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
382
/// A concrete function invocation: the function name and its arguments
/// as a raw string (JSON-encoded by the model).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
388
/// Incremental message delta carried by a streamed [`ChoiceDelta`].
/// All fields are optional because each chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Partial tool calls; omitted from serialization when absent or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
396
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of the tool call this fragment belongs to.
    pub index: usize,
    /// Call id; typically present only on the first fragment of a call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
407
/// A streamed fragment of a function invocation; `arguments` fragments
/// concatenate across chunks of the same call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
413
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
420
/// One choice's incremental update within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Set on the final chunk of this choice (e.g. why generation stopped).
    pub finish_reason: Option<String>,
}
427
/// A single deserialized SSE payload: either a normal event or an
/// in-stream error object (`{"error": "..."}`). Untagged, so whichever
/// shape matches the JSON wins.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
434
/// One event in a streamed chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Usage totals; present on the final event when the API includes them.
    pub usage: Option<Usage>,
}
441
/// Sends a streaming chat-completion request to `{api_url}/chat/completions`
/// and returns a stream of parsed server-sent events.
///
/// Each SSE line of the form `data: {...}` is deserialized into a
/// [`ResponseStreamEvent`]; the terminal `data: [DONE]` marker ends the
/// stream silently. On a non-success HTTP status the whole body is read
/// and an error is returned immediately, preferring the `error.message`
/// field of the JSON error body when present.
///
/// # Errors
/// Fails if the request cannot be serialized or built, if the HTTP call
/// fails, or if the server responds with a non-success status. Per-event
/// parse failures surface as `Err` items *within* the returned stream.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines carry a "data: " prefix; skip any line without it.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel; drop it rather than forward it.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => {
                                    // Log the raw payload so malformed events can be diagnosed.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Failure path: consume the whole body and try to surface the API's own message.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "API request to {} failed: {}",
                api_url,
                response.error.message,
            )),

            // Fall back to the raw status and body when the error JSON is absent or empty.
            _ => anyhow::bail!(
                "API request to {} failed with status {}: {}",
                api_url,
                response.status(),
                body,
            ),
        }
    }
}
519
/// Embedding model identifier, serialized as the API model id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
527
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
533
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
538
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
543
544pub fn embed<'a>(
545 client: &dyn HttpClient,
546 api_url: &str,
547 api_key: &str,
548 model: OpenAiEmbeddingModel,
549 texts: impl IntoIterator<Item = &'a str>,
550) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
551 let uri = format!("{api_url}/embeddings");
552
553 let request = OpenAiEmbeddingRequest {
554 model,
555 input: texts.into_iter().collect(),
556 };
557 let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
558 let request = HttpRequest::builder()
559 .method(Method::POST)
560 .uri(uri)
561 .header("Content-Type", "application/json")
562 .header("Authorization", format!("Bearer {}", api_key))
563 .body(body)
564 .map(|request| client.send(request));
565
566 async move {
567 let mut response = request?.await?;
568 let mut body = String::new();
569 response.body_mut().read_to_string(&mut body).await?;
570
571 anyhow::ensure!(
572 response.status().is_success(),
573 "error during embedding, status: {:?}, body: {:?}",
574 response.status(),
575 body
576 );
577 let response: OpenAiEmbeddingResponse =
578 serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
579 Ok(response)
580 }
581}