open_ai.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, future::Future};
  7use strum::EnumIter;
  8
  9pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 10
 11fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 12    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 13}
 14
 15#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 16#[serde(rename_all = "lowercase")]
 17pub enum Role {
 18    User,
 19    Assistant,
 20    System,
 21    Tool,
 22}
 23
 24impl TryFrom<String> for Role {
 25    type Error = anyhow::Error;
 26
 27    fn try_from(value: String) -> Result<Self> {
 28        match value.as_str() {
 29            "user" => Ok(Self::User),
 30            "assistant" => Ok(Self::Assistant),
 31            "system" => Ok(Self::System),
 32            "tool" => Ok(Self::Tool),
 33            _ => anyhow::bail!("invalid role '{value}'"),
 34        }
 35    }
 36}
 37
 38impl From<Role> for String {
 39    fn from(val: Role) -> Self {
 40        match val {
 41            Role::User => "user".to_owned(),
 42            Role::Assistant => "assistant".to_owned(),
 43            Role::System => "system".to_owned(),
 44            Role::Tool => "tool".to_owned(),
 45        }
 46    }
 47}
 48
 49#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 50#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 51pub enum Model {
 52    #[serde(rename = "gpt-3.5-turbo")]
 53    ThreePointFiveTurbo,
 54    #[serde(rename = "gpt-4")]
 55    Four,
 56    #[serde(rename = "gpt-4-turbo")]
 57    FourTurbo,
 58    #[serde(rename = "gpt-4o")]
 59    #[default]
 60    FourOmni,
 61    #[serde(rename = "gpt-4o-mini")]
 62    FourOmniMini,
 63    #[serde(rename = "gpt-4.1")]
 64    FourPointOne,
 65    #[serde(rename = "gpt-4.1-mini")]
 66    FourPointOneMini,
 67    #[serde(rename = "gpt-4.1-nano")]
 68    FourPointOneNano,
 69    #[serde(rename = "o1")]
 70    O1,
 71    #[serde(rename = "o3-mini")]
 72    O3Mini,
 73    #[serde(rename = "o3")]
 74    O3,
 75    #[serde(rename = "o4-mini")]
 76    O4Mini,
 77    #[serde(rename = "gpt-5")]
 78    Five,
 79    #[serde(rename = "gpt-5-mini")]
 80    FiveMini,
 81    #[serde(rename = "gpt-5-nano")]
 82    FiveNano,
 83
 84    #[serde(rename = "custom")]
 85    Custom {
 86        name: String,
 87        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 88        display_name: Option<String>,
 89        max_tokens: u64,
 90        max_output_tokens: Option<u64>,
 91        max_completion_tokens: Option<u64>,
 92        reasoning_effort: Option<ReasoningEffort>,
 93    },
 94}
 95
 96impl Model {
 97    pub fn default_fast() -> Self {
 98        // TODO: Replace with FiveMini since all other models are deprecated
 99        Self::FourPointOneMini
100    }
101
102    pub fn from_id(id: &str) -> Result<Self> {
103        match id {
104            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
105            "gpt-4" => Ok(Self::Four),
106            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
107            "gpt-4o" => Ok(Self::FourOmni),
108            "gpt-4o-mini" => Ok(Self::FourOmniMini),
109            "gpt-4.1" => Ok(Self::FourPointOne),
110            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
111            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
112            "o1" => Ok(Self::O1),
113            "o3-mini" => Ok(Self::O3Mini),
114            "o3" => Ok(Self::O3),
115            "o4-mini" => Ok(Self::O4Mini),
116            "gpt-5" => Ok(Self::Five),
117            "gpt-5-mini" => Ok(Self::FiveMini),
118            "gpt-5-nano" => Ok(Self::FiveNano),
119            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
120        }
121    }
122
123    pub fn id(&self) -> &str {
124        match self {
125            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
126            Self::Four => "gpt-4",
127            Self::FourTurbo => "gpt-4-turbo",
128            Self::FourOmni => "gpt-4o",
129            Self::FourOmniMini => "gpt-4o-mini",
130            Self::FourPointOne => "gpt-4.1",
131            Self::FourPointOneMini => "gpt-4.1-mini",
132            Self::FourPointOneNano => "gpt-4.1-nano",
133            Self::O1 => "o1",
134            Self::O3Mini => "o3-mini",
135            Self::O3 => "o3",
136            Self::O4Mini => "o4-mini",
137            Self::Five => "gpt-5",
138            Self::FiveMini => "gpt-5-mini",
139            Self::FiveNano => "gpt-5-nano",
140            Self::Custom { name, .. } => name,
141        }
142    }
143
144    pub fn display_name(&self) -> &str {
145        match self {
146            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147            Self::Four => "gpt-4",
148            Self::FourTurbo => "gpt-4-turbo",
149            Self::FourOmni => "gpt-4o",
150            Self::FourOmniMini => "gpt-4o-mini",
151            Self::FourPointOne => "gpt-4.1",
152            Self::FourPointOneMini => "gpt-4.1-mini",
153            Self::FourPointOneNano => "gpt-4.1-nano",
154            Self::O1 => "o1",
155            Self::O3Mini => "o3-mini",
156            Self::O3 => "o3",
157            Self::O4Mini => "o4-mini",
158            Self::Five => "gpt-5",
159            Self::FiveMini => "gpt-5-mini",
160            Self::FiveNano => "gpt-5-nano",
161            Self::Custom {
162                name, display_name, ..
163            } => display_name.as_ref().unwrap_or(name),
164        }
165    }
166
167    pub fn max_token_count(&self) -> u64 {
168        match self {
169            Self::ThreePointFiveTurbo => 16_385,
170            Self::Four => 8_192,
171            Self::FourTurbo => 128_000,
172            Self::FourOmni => 128_000,
173            Self::FourOmniMini => 128_000,
174            Self::FourPointOne => 1_047_576,
175            Self::FourPointOneMini => 1_047_576,
176            Self::FourPointOneNano => 1_047_576,
177            Self::O1 => 200_000,
178            Self::O3Mini => 200_000,
179            Self::O3 => 200_000,
180            Self::O4Mini => 200_000,
181            Self::Five => 272_000,
182            Self::FiveMini => 272_000,
183            Self::FiveNano => 272_000,
184            Self::Custom { max_tokens, .. } => *max_tokens,
185        }
186    }
187
188    pub fn max_output_tokens(&self) -> Option<u64> {
189        match self {
190            Self::Custom {
191                max_output_tokens, ..
192            } => *max_output_tokens,
193            Self::ThreePointFiveTurbo => Some(4_096),
194            Self::Four => Some(8_192),
195            Self::FourTurbo => Some(4_096),
196            Self::FourOmni => Some(16_384),
197            Self::FourOmniMini => Some(16_384),
198            Self::FourPointOne => Some(32_768),
199            Self::FourPointOneMini => Some(32_768),
200            Self::FourPointOneNano => Some(32_768),
201            Self::O1 => Some(100_000),
202            Self::O3Mini => Some(100_000),
203            Self::O3 => Some(100_000),
204            Self::O4Mini => Some(100_000),
205            Self::Five => Some(128_000),
206            Self::FiveMini => Some(128_000),
207            Self::FiveNano => Some(128_000),
208        }
209    }
210
211    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
212        match self {
213            Self::Custom {
214                reasoning_effort, ..
215            } => reasoning_effort.to_owned(),
216            _ => None,
217        }
218    }
219
220    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
221    ///
222    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
223    pub fn supports_parallel_tool_calls(&self) -> bool {
224        match self {
225            Self::ThreePointFiveTurbo
226            | Self::Four
227            | Self::FourTurbo
228            | Self::FourOmni
229            | Self::FourOmniMini
230            | Self::FourPointOne
231            | Self::FourPointOneMini
232            | Self::FourPointOneNano
233            | Self::Five
234            | Self::FiveMini
235            | Self::FiveNano => true,
236            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
237        }
238    }
239
240    /// Returns whether the given model supports the `prompt_cache_key` parameter.
241    ///
242    /// If the model does not support the parameter, do not pass it up.
243    pub fn supports_prompt_cache_key(&self) -> bool {
244        true
245    }
246}
247
248#[derive(Debug, Serialize, Deserialize)]
249pub struct Request {
250    pub model: String,
251    pub messages: Vec<RequestMessage>,
252    pub stream: bool,
253    #[serde(default, skip_serializing_if = "Option::is_none")]
254    pub max_completion_tokens: Option<u64>,
255    #[serde(default, skip_serializing_if = "Vec::is_empty")]
256    pub stop: Vec<String>,
257    pub temperature: f32,
258    #[serde(default, skip_serializing_if = "Option::is_none")]
259    pub tool_choice: Option<ToolChoice>,
260    /// Whether to enable parallel function calling during tool use.
261    #[serde(default, skip_serializing_if = "Option::is_none")]
262    pub parallel_tool_calls: Option<bool>,
263    #[serde(default, skip_serializing_if = "Vec::is_empty")]
264    pub tools: Vec<ToolDefinition>,
265    #[serde(default, skip_serializing_if = "Option::is_none")]
266    pub prompt_cache_key: Option<String>,
267    #[serde(default, skip_serializing_if = "Option::is_none")]
268    pub reasoning_effort: Option<ReasoningEffort>,
269}
270
271#[derive(Debug, Serialize, Deserialize)]
272#[serde(rename_all = "lowercase")]
273pub enum ToolChoice {
274    Auto,
275    Required,
276    None,
277    #[serde(untagged)]
278    Other(ToolDefinition),
279}
280
281#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
282#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
283#[serde(rename_all = "lowercase")]
284pub enum ReasoningEffort {
285    Minimal,
286    Low,
287    Medium,
288    High,
289}
290
291#[derive(Clone, Deserialize, Serialize, Debug)]
292#[serde(tag = "type", rename_all = "snake_case")]
293pub enum ToolDefinition {
294    #[allow(dead_code)]
295    Function { function: FunctionDefinition },
296}
297
298#[derive(Clone, Debug, Serialize, Deserialize)]
299pub struct FunctionDefinition {
300    pub name: String,
301    pub description: Option<String>,
302    pub parameters: Option<Value>,
303}
304
305#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
306#[serde(tag = "role", rename_all = "lowercase")]
307pub enum RequestMessage {
308    Assistant {
309        content: Option<MessageContent>,
310        #[serde(default, skip_serializing_if = "Vec::is_empty")]
311        tool_calls: Vec<ToolCall>,
312    },
313    User {
314        content: MessageContent,
315    },
316    System {
317        content: MessageContent,
318    },
319    Tool {
320        content: MessageContent,
321        tool_call_id: String,
322    },
323}
324
325#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
326#[serde(untagged)]
327pub enum MessageContent {
328    Plain(String),
329    Multipart(Vec<MessagePart>),
330}
331
332impl MessageContent {
333    pub fn empty() -> Self {
334        MessageContent::Multipart(vec![])
335    }
336
337    pub fn push_part(&mut self, part: MessagePart) {
338        match self {
339            MessageContent::Plain(text) => {
340                *self =
341                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
342            }
343            MessageContent::Multipart(parts) if parts.is_empty() => match part {
344                MessagePart::Text { text } => *self = MessageContent::Plain(text),
345                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
346            },
347            MessageContent::Multipart(parts) => parts.push(part),
348        }
349    }
350}
351
352impl From<Vec<MessagePart>> for MessageContent {
353    fn from(mut parts: Vec<MessagePart>) -> Self {
354        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
355            MessageContent::Plain(std::mem::take(text))
356        } else {
357            MessageContent::Multipart(parts)
358        }
359    }
360}
361
362#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
363#[serde(tag = "type")]
364pub enum MessagePart {
365    #[serde(rename = "text")]
366    Text { text: String },
367    #[serde(rename = "image_url")]
368    Image { image_url: ImageUrl },
369}
370
371#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
372pub struct ImageUrl {
373    pub url: String,
374    #[serde(skip_serializing_if = "Option::is_none")]
375    pub detail: Option<String>,
376}
377
378#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
379pub struct ToolCall {
380    pub id: String,
381    #[serde(flatten)]
382    pub content: ToolCallContent,
383}
384
385#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
386#[serde(tag = "type", rename_all = "lowercase")]
387pub enum ToolCallContent {
388    Function { function: FunctionContent },
389}
390
391#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
392pub struct FunctionContent {
393    pub name: String,
394    pub arguments: String,
395}
396
397#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
398pub struct ResponseMessageDelta {
399    pub role: Option<Role>,
400    pub content: Option<String>,
401    #[serde(default, skip_serializing_if = "is_none_or_empty")]
402    pub tool_calls: Option<Vec<ToolCallChunk>>,
403}
404
405#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
406pub struct ToolCallChunk {
407    pub index: usize,
408    pub id: Option<String>,
409
410    // There is also an optional `type` field that would determine if a
411    // function is there. Sometimes this streams in with the `function` before
412    // it streams in the `type`
413    pub function: Option<FunctionChunk>,
414}
415
416#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
417pub struct FunctionChunk {
418    pub name: Option<String>,
419    pub arguments: Option<String>,
420}
421
422#[derive(Serialize, Deserialize, Debug)]
423pub struct Usage {
424    pub prompt_tokens: u64,
425    pub completion_tokens: u64,
426    pub total_tokens: u64,
427}
428
429#[derive(Serialize, Deserialize, Debug)]
430pub struct ChoiceDelta {
431    pub index: u32,
432    pub delta: ResponseMessageDelta,
433    pub finish_reason: Option<String>,
434}
435
436#[derive(Serialize, Deserialize, Debug)]
437pub struct OpenAiError {
438    message: String,
439}
440
441#[derive(Serialize, Deserialize, Debug)]
442#[serde(untagged)]
443pub enum ResponseStreamResult {
444    Ok(ResponseStreamEvent),
445    Err { error: OpenAiError },
446}
447
448#[derive(Serialize, Deserialize, Debug)]
449pub struct ResponseStreamEvent {
450    pub choices: Vec<ChoiceDelta>,
451    pub usage: Option<Usage>,
452}
453
454pub async fn stream_completion(
455    client: &dyn HttpClient,
456    api_url: &str,
457    api_key: &str,
458    request: Request,
459) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
460    let uri = format!("{api_url}/chat/completions");
461    let request_builder = HttpRequest::builder()
462        .method(Method::POST)
463        .uri(uri)
464        .header("Content-Type", "application/json")
465        .header("Authorization", format!("Bearer {}", api_key.trim()));
466
467    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
468    let mut response = client.send(request).await?;
469    if response.status().is_success() {
470        let reader = BufReader::new(response.into_body());
471        Ok(reader
472            .lines()
473            .filter_map(|line| async move {
474                match line {
475                    Ok(line) => {
476                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
477                        if line == "[DONE]" {
478                            None
479                        } else {
480                            match serde_json::from_str(line) {
481                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
482                                Ok(ResponseStreamResult::Err { error }) => {
483                                    Some(Err(anyhow!(error.message)))
484                                }
485                                Err(error) => {
486                                    log::error!(
487                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
488                                        Response: `{}`",
489                                        error,
490                                        line,
491                                    );
492                                    Some(Err(anyhow!(error)))
493                                }
494                            }
495                        }
496                    }
497                    Err(error) => Some(Err(anyhow!(error))),
498                }
499            })
500            .boxed())
501    } else {
502        let mut body = String::new();
503        response.body_mut().read_to_string(&mut body).await?;
504
505        #[derive(Deserialize)]
506        struct OpenAiResponse {
507            error: OpenAiError,
508        }
509
510        match serde_json::from_str::<OpenAiResponse>(&body) {
511            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
512                "API request to {} failed: {}",
513                api_url,
514                response.error.message,
515            )),
516
517            _ => anyhow::bail!(
518                "API request to {} failed with status {}: {}",
519                api_url,
520                response.status(),
521                body,
522            ),
523        }
524    }
525}
526
527#[derive(Copy, Clone, Serialize, Deserialize)]
528pub enum OpenAiEmbeddingModel {
529    #[serde(rename = "text-embedding-3-small")]
530    TextEmbedding3Small,
531    #[serde(rename = "text-embedding-3-large")]
532    TextEmbedding3Large,
533}
534
535#[derive(Serialize)]
536struct OpenAiEmbeddingRequest<'a> {
537    model: OpenAiEmbeddingModel,
538    input: Vec<&'a str>,
539}
540
541#[derive(Deserialize)]
542pub struct OpenAiEmbeddingResponse {
543    pub data: Vec<OpenAiEmbedding>,
544}
545
546#[derive(Deserialize)]
547pub struct OpenAiEmbedding {
548    pub embedding: Vec<f32>,
549}
550
551pub fn embed<'a>(
552    client: &dyn HttpClient,
553    api_url: &str,
554    api_key: &str,
555    model: OpenAiEmbeddingModel,
556    texts: impl IntoIterator<Item = &'a str>,
557) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
558    let uri = format!("{api_url}/embeddings");
559
560    let request = OpenAiEmbeddingRequest {
561        model,
562        input: texts.into_iter().collect(),
563    };
564    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
565    let request = HttpRequest::builder()
566        .method(Method::POST)
567        .uri(uri)
568        .header("Content-Type", "application/json")
569        .header("Authorization", format!("Bearer {}", api_key.trim()))
570        .body(body)
571        .map(|request| client.send(request));
572
573    async move {
574        let mut response = request?.await?;
575        let mut body = String::new();
576        response.body_mut().read_to_string(&mut body).await?;
577
578        anyhow::ensure!(
579            response.status().is_success(),
580            "error during embedding, status: {:?}, body: {:?}",
581            response.status(),
582            body
583        );
584        let response: OpenAiEmbeddingResponse =
585            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
586        Ok(response)
587    }
588}