// open_ai.rs

  1pub mod responses;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{
  6    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  7    http::{HeaderMap, HeaderValue},
  8};
  9use serde::{Deserialize, Serialize};
 10use serde_json::Value;
 11pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 12use std::{convert::TryFrom, future::Future};
 13use strum::EnumIter;
 14use thiserror::Error;
 15
 16pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 17
 18fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 19    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 20}
 21
/// Author role of a chat message, serialized in OpenAI's lowercase wire format
/// (e.g. `"user"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 30
 31impl TryFrom<String> for Role {
 32    type Error = anyhow::Error;
 33
 34    fn try_from(value: String) -> Result<Self> {
 35        match value.as_str() {
 36            "user" => Ok(Self::User),
 37            "assistant" => Ok(Self::Assistant),
 38            "system" => Ok(Self::System),
 39            "tool" => Ok(Self::Tool),
 40            _ => anyhow::bail!("invalid role '{value}'"),
 41        }
 42    }
 43}
 44
 45impl From<Role> for String {
 46    fn from(val: Role) -> Self {
 47        match val {
 48            Role::User => "user".to_owned(),
 49            Role::Assistant => "assistant".to_owned(),
 50            Role::System => "system".to_owned(),
 51            Role::Tool => "tool".to_owned(),
 52        }
 53    }
 54}
 55
/// The set of OpenAI models this client knows about, plus a `Custom` escape
/// hatch for user-configured models. Serde renames map each variant to its
/// API model id.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A user-configured model not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        // The id sent to the API in the `model` field.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Maximum context length (input + output) in tokens.
        max_tokens: u64,
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        reasoning_effort: Option<ReasoningEffort>,
        // Defaults to `true` when absent in configuration.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
111
// Serde default for `Model::Custom::supports_chat_completions`: custom models
// are assumed to support the Chat Completions API unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
115
116impl Model {
117    pub fn default_fast() -> Self {
118        // TODO: Replace with FiveMini since all other models are deprecated
119        Self::FourPointOneMini
120    }
121
122    pub fn from_id(id: &str) -> Result<Self> {
123        match id {
124            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
125            "gpt-4" => Ok(Self::Four),
126            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
127            "gpt-4o" => Ok(Self::FourOmni),
128            "gpt-4o-mini" => Ok(Self::FourOmniMini),
129            "gpt-4.1" => Ok(Self::FourPointOne),
130            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
131            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
132            "o1" => Ok(Self::O1),
133            "o3-mini" => Ok(Self::O3Mini),
134            "o3" => Ok(Self::O3),
135            "o4-mini" => Ok(Self::O4Mini),
136            "gpt-5" => Ok(Self::Five),
137            "gpt-5-codex" => Ok(Self::FiveCodex),
138            "gpt-5-mini" => Ok(Self::FiveMini),
139            "gpt-5-nano" => Ok(Self::FiveNano),
140            "gpt-5.1" => Ok(Self::FivePointOne),
141            "gpt-5.2" => Ok(Self::FivePointTwo),
142            "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
143            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
144        }
145    }
146
147    pub fn id(&self) -> &str {
148        match self {
149            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
150            Self::Four => "gpt-4",
151            Self::FourTurbo => "gpt-4-turbo",
152            Self::FourOmni => "gpt-4o",
153            Self::FourOmniMini => "gpt-4o-mini",
154            Self::FourPointOne => "gpt-4.1",
155            Self::FourPointOneMini => "gpt-4.1-mini",
156            Self::FourPointOneNano => "gpt-4.1-nano",
157            Self::O1 => "o1",
158            Self::O3Mini => "o3-mini",
159            Self::O3 => "o3",
160            Self::O4Mini => "o4-mini",
161            Self::Five => "gpt-5",
162            Self::FiveCodex => "gpt-5-codex",
163            Self::FiveMini => "gpt-5-mini",
164            Self::FiveNano => "gpt-5-nano",
165            Self::FivePointOne => "gpt-5.1",
166            Self::FivePointTwo => "gpt-5.2",
167            Self::FivePointTwoCodex => "gpt-5.2-codex",
168            Self::Custom { name, .. } => name,
169        }
170    }
171
172    pub fn display_name(&self) -> &str {
173        match self {
174            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
175            Self::Four => "gpt-4",
176            Self::FourTurbo => "gpt-4-turbo",
177            Self::FourOmni => "gpt-4o",
178            Self::FourOmniMini => "gpt-4o-mini",
179            Self::FourPointOne => "gpt-4.1",
180            Self::FourPointOneMini => "gpt-4.1-mini",
181            Self::FourPointOneNano => "gpt-4.1-nano",
182            Self::O1 => "o1",
183            Self::O3Mini => "o3-mini",
184            Self::O3 => "o3",
185            Self::O4Mini => "o4-mini",
186            Self::Five => "gpt-5",
187            Self::FiveCodex => "gpt-5-codex",
188            Self::FiveMini => "gpt-5-mini",
189            Self::FiveNano => "gpt-5-nano",
190            Self::FivePointOne => "gpt-5.1",
191            Self::FivePointTwo => "gpt-5.2",
192            Self::FivePointTwoCodex => "gpt-5.2-codex",
193            Self::Custom {
194                name, display_name, ..
195            } => display_name.as_ref().unwrap_or(name),
196        }
197    }
198
199    pub fn max_token_count(&self) -> u64 {
200        match self {
201            Self::ThreePointFiveTurbo => 16_385,
202            Self::Four => 8_192,
203            Self::FourTurbo => 128_000,
204            Self::FourOmni => 128_000,
205            Self::FourOmniMini => 128_000,
206            Self::FourPointOne => 1_047_576,
207            Self::FourPointOneMini => 1_047_576,
208            Self::FourPointOneNano => 1_047_576,
209            Self::O1 => 200_000,
210            Self::O3Mini => 200_000,
211            Self::O3 => 200_000,
212            Self::O4Mini => 200_000,
213            Self::Five => 272_000,
214            Self::FiveCodex => 272_000,
215            Self::FiveMini => 272_000,
216            Self::FiveNano => 272_000,
217            Self::FivePointOne => 400_000,
218            Self::FivePointTwo => 400_000,
219            Self::FivePointTwoCodex => 400_000,
220            Self::Custom { max_tokens, .. } => *max_tokens,
221        }
222    }
223
224    pub fn max_output_tokens(&self) -> Option<u64> {
225        match self {
226            Self::Custom {
227                max_output_tokens, ..
228            } => *max_output_tokens,
229            Self::ThreePointFiveTurbo => Some(4_096),
230            Self::Four => Some(8_192),
231            Self::FourTurbo => Some(4_096),
232            Self::FourOmni => Some(16_384),
233            Self::FourOmniMini => Some(16_384),
234            Self::FourPointOne => Some(32_768),
235            Self::FourPointOneMini => Some(32_768),
236            Self::FourPointOneNano => Some(32_768),
237            Self::O1 => Some(100_000),
238            Self::O3Mini => Some(100_000),
239            Self::O3 => Some(100_000),
240            Self::O4Mini => Some(100_000),
241            Self::Five => Some(128_000),
242            Self::FiveCodex => Some(128_000),
243            Self::FiveMini => Some(128_000),
244            Self::FiveNano => Some(128_000),
245            Self::FivePointOne => Some(128_000),
246            Self::FivePointTwo => Some(128_000),
247            Self::FivePointTwoCodex => Some(128_000),
248        }
249    }
250
251    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
252        match self {
253            Self::Custom {
254                reasoning_effort, ..
255            } => reasoning_effort.to_owned(),
256            _ => None,
257        }
258    }
259
260    pub fn supports_chat_completions(&self) -> bool {
261        match self {
262            Self::Custom {
263                supports_chat_completions,
264                ..
265            } => *supports_chat_completions,
266            Self::FiveCodex | Self::FivePointTwoCodex => false,
267            _ => true,
268        }
269    }
270
271    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
272    ///
273    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
274    pub fn supports_parallel_tool_calls(&self) -> bool {
275        match self {
276            Self::ThreePointFiveTurbo
277            | Self::Four
278            | Self::FourTurbo
279            | Self::FourOmni
280            | Self::FourOmniMini
281            | Self::FourPointOne
282            | Self::FourPointOneMini
283            | Self::FourPointOneNano
284            | Self::Five
285            | Self::FiveCodex
286            | Self::FiveMini
287            | Self::FivePointOne
288            | Self::FivePointTwo
289            | Self::FivePointTwoCodex
290            | Self::FiveNano => true,
291            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
292        }
293    }
294
295    /// Returns whether the given model supports the `prompt_cache_key` parameter.
296    ///
297    /// If the model does not support the parameter, do not pass it up.
298    pub fn supports_prompt_cache_key(&self) -> bool {
299        true
300    }
301}
302
/// Request body for the Chat Completions endpoint. Optional/empty fields are
/// omitted from the serialized JSON so unsupported parameters are never sent.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model id, as returned by `Model::id()`.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true the response is delivered as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
326
/// The `tool_choice` request field: one of the API's mode strings
/// ("auto"/"required"/"none") or, via the untagged variant, a specific tool.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    // Serialized as the tool definition object itself (no enum tag).
    #[serde(untagged)]
    Other(ToolDefinition),
}
336
/// A tool made available to the model; currently only function tools exist,
/// serialized as `{"type": "function", "function": {...}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
343
/// Schema of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON Schema describing the function's arguments, if any.
    pub parameters: Option<Value>,
}
350
/// A chat message in a request, internally tagged by its `role` field.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        // Optional: assistant turns may consist solely of tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        // The tool's output, paired with the id of the call it answers.
        content: MessageContent,
        tool_call_id: String,
    },
}
370
/// Message content: either a bare string or a list of typed parts
/// (text/images). Untagged so both JSON shapes deserialize transparently.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
377
378impl MessageContent {
379    pub fn empty() -> Self {
380        MessageContent::Multipart(vec![])
381    }
382
383    pub fn push_part(&mut self, part: MessagePart) {
384        match self {
385            MessageContent::Plain(text) => {
386                *self =
387                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
388            }
389            MessageContent::Multipart(parts) if parts.is_empty() => match part {
390                MessagePart::Text { text } => *self = MessageContent::Plain(text),
391                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
392            },
393            MessageContent::Multipart(parts) => parts.push(part),
394        }
395    }
396}
397
398impl From<Vec<MessagePart>> for MessageContent {
399    fn from(mut parts: Vec<MessagePart>) -> Self {
400        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
401            MessageContent::Plain(std::mem::take(text))
402        } else {
403            MessageContent::Multipart(parts)
404        }
405    }
406}
407
/// One element of a multipart message, externally tagged by `type`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
416
/// An image reference in a message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    // Optional detail level string; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
423
/// A completed tool call made by the assistant; the flattened content
/// contributes the `type` and payload fields alongside `id`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
430
/// Payload of a tool call, tagged by `type`; only function calls exist today.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
436
/// A function invocation: the function name plus its arguments as a raw
/// JSON-encoded string (as delivered by the API).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
442
/// A complete (non-streaming) chat-completion response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation, per the API's `created` field.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
452
/// One candidate completion within a [`Response`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
459
/// Incremental message delta in a streamed response chunk; every field may be
/// absent because the API streams message pieces independently.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    // Reasoning text emitted by some providers; skipped when absent/empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub reasoning_content: Option<String>,
}
469
/// A streamed fragment of a tool call; fragments sharing the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
480
/// A streamed fragment of a function invocation; `arguments` arrives as
/// partial JSON text to be concatenated across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
486
/// Token accounting reported by the API for a request.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
493
/// One choice's delta within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
500
/// Errors produced by [`stream_completion`]: either a non-success HTTP
/// response (with status, body, and headers preserved for the caller) or any
/// other transport/serialization failure.
#[derive(Error, Debug)]
pub enum RequestError {
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
513
/// Error payload the API may emit inside a stream event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
518
/// A streamed SSE payload: either a normal event or an `{"error": ...}`
/// object. Untagged so serde tries each shape in order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
525
/// One server-sent event of a streaming chat completion; `usage` is typically
/// only populated on the final event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
531
/// Sends a streaming request to `{api_url}/chat/completions` and returns a
/// stream of parsed server-sent events.
///
/// # Errors
/// Returns [`RequestError::HttpResponseError`] (carrying the status, body,
/// and headers) for non-2xx responses, and [`RequestError::Other`] for
/// request-building, transport, or serialization failures. Parse errors on
/// individual events surface as `Err` items within the stream.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Trim the key defensively: trailing whitespace/newlines from config
        // files would otherwise corrupt the Authorization header.
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the body as a server-sent-event stream, one event per line.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines are prefixed "data: "; accept the
                        // space-less "data:" form as well. Lines without the
                        // prefix (keep-alives, blank lines) are dropped.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: not JSON, just drop it.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the raw payload to help diagnose
                                    // schema drift in the upstream API.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-success: read the whole body so the caller gets the API's
        // error text alongside the status and headers.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
601
/// Embedding models supported by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
609
/// Request body for the embeddings endpoint; borrows the input texts to avoid
/// copying them before serialization.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
615
/// Response from the embeddings endpoint: one entry per input text, in order.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
620
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
625
/// Requests embeddings for `texts` from `{api_url}/embeddings`.
///
/// The HTTP request is built eagerly (borrowing the arguments), then the
/// returned `'static` future performs the round-trip and parses the response,
/// so it can be spawned without holding onto `client` or the input texts.
///
/// # Errors
/// Fails on request-building or transport errors, on a non-success HTTP
/// status (the body is included in the message), or if the response body
/// cannot be parsed as an [`OpenAiEmbeddingResponse`].
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serializing this request only involves strings and a unit-variant enum,
    // so `to_string` cannot fail here.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        // Trim to guard against stray whitespace/newlines in configured keys.
        .header("Authorization", format!("Bearer {}", api_key.trim()))
        .body(body)
        .map(|request| client.send(request));

    async move {
        // `request` is a Result: builder errors surface here via `?` before
        // the send future is awaited.
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}