open_ai.rs

  1pub mod responses;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{
  6    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  7    http::{HeaderMap, HeaderValue},
  8};
  9use serde::{Deserialize, Serialize};
 10use serde_json::Value;
 11pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 12use std::{convert::TryFrom, future::Future};
 13use strum::EnumIter;
 14use thiserror::Error;
 15
 16pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 17
 18fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 19    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 20}
 21
/// The author of a chat message, serialized as the lowercase `role`
/// field of the chat completions wire format ("user", "assistant",
/// "system", or "tool").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 30
 31impl TryFrom<String> for Role {
 32    type Error = anyhow::Error;
 33
 34    fn try_from(value: String) -> Result<Self> {
 35        match value.as_str() {
 36            "user" => Ok(Self::User),
 37            "assistant" => Ok(Self::Assistant),
 38            "system" => Ok(Self::System),
 39            "tool" => Ok(Self::Tool),
 40            _ => anyhow::bail!("invalid role '{value}'"),
 41        }
 42    }
 43}
 44
 45impl From<Role> for String {
 46    fn from(val: Role) -> Self {
 47        match val {
 48            Role::User => "user".to_owned(),
 49            Role::Assistant => "assistant".to_owned(),
 50            Role::System => "system".to_owned(),
 51            Role::Tool => "tool".to_owned(),
 52        }
 53    }
 54}
 55
/// The OpenAI models known to this crate, plus a `Custom` escape hatch
/// for arbitrary model ids.
///
/// Each variant serializes as its wire id (e.g. "gpt-4o").
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    /// A user-configured model not in the list above.
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API (returned by `Model::id`).
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Context window size in tokens (see `Model::max_token_count`).
        max_tokens: u64,
        /// Output-token cap, if known (see `Model::max_output_tokens`).
        max_output_tokens: Option<u64>,
        max_completion_tokens: Option<u64>,
        /// Reasoning effort to request, if configured.
        reasoning_effort: Option<ReasoningEffort>,
        /// Whether the model can be used via the chat completions API.
        /// Defaults to `true` when absent from configuration.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
111
/// Serde default for `Model::Custom::supports_chat_completions`.
const fn default_supports_chat_completions() -> bool {
    true
}
115
116impl Model {
117    pub fn default_fast() -> Self {
118        // TODO: Replace with FiveMini since all other models are deprecated
119        Self::FourPointOneMini
120    }
121
122    pub fn from_id(id: &str) -> Result<Self> {
123        match id {
124            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
125            "gpt-4" => Ok(Self::Four),
126            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
127            "gpt-4o" => Ok(Self::FourOmni),
128            "gpt-4o-mini" => Ok(Self::FourOmniMini),
129            "gpt-4.1" => Ok(Self::FourPointOne),
130            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
131            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
132            "o1" => Ok(Self::O1),
133            "o3-mini" => Ok(Self::O3Mini),
134            "o3" => Ok(Self::O3),
135            "o4-mini" => Ok(Self::O4Mini),
136            "gpt-5" => Ok(Self::Five),
137            "gpt-5-codex" => Ok(Self::FiveCodex),
138            "gpt-5-mini" => Ok(Self::FiveMini),
139            "gpt-5-nano" => Ok(Self::FiveNano),
140            "gpt-5.1" => Ok(Self::FivePointOne),
141            "gpt-5.2" => Ok(Self::FivePointTwo),
142            "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
143            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
144        }
145    }
146
147    pub fn id(&self) -> &str {
148        match self {
149            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
150            Self::Four => "gpt-4",
151            Self::FourTurbo => "gpt-4-turbo",
152            Self::FourOmni => "gpt-4o",
153            Self::FourOmniMini => "gpt-4o-mini",
154            Self::FourPointOne => "gpt-4.1",
155            Self::FourPointOneMini => "gpt-4.1-mini",
156            Self::FourPointOneNano => "gpt-4.1-nano",
157            Self::O1 => "o1",
158            Self::O3Mini => "o3-mini",
159            Self::O3 => "o3",
160            Self::O4Mini => "o4-mini",
161            Self::Five => "gpt-5",
162            Self::FiveCodex => "gpt-5-codex",
163            Self::FiveMini => "gpt-5-mini",
164            Self::FiveNano => "gpt-5-nano",
165            Self::FivePointOne => "gpt-5.1",
166            Self::FivePointTwo => "gpt-5.2",
167            Self::FivePointTwoCodex => "gpt-5.2-codex",
168            Self::Custom { name, .. } => name,
169        }
170    }
171
172    pub fn display_name(&self) -> &str {
173        match self {
174            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
175            Self::Four => "gpt-4",
176            Self::FourTurbo => "gpt-4-turbo",
177            Self::FourOmni => "gpt-4o",
178            Self::FourOmniMini => "gpt-4o-mini",
179            Self::FourPointOne => "gpt-4.1",
180            Self::FourPointOneMini => "gpt-4.1-mini",
181            Self::FourPointOneNano => "gpt-4.1-nano",
182            Self::O1 => "o1",
183            Self::O3Mini => "o3-mini",
184            Self::O3 => "o3",
185            Self::O4Mini => "o4-mini",
186            Self::Five => "gpt-5",
187            Self::FiveCodex => "gpt-5-codex",
188            Self::FiveMini => "gpt-5-mini",
189            Self::FiveNano => "gpt-5-nano",
190            Self::FivePointOne => "gpt-5.1",
191            Self::FivePointTwo => "gpt-5.2",
192            Self::FivePointTwoCodex => "gpt-5.2-codex",
193            Self::Custom {
194                name, display_name, ..
195            } => display_name.as_ref().unwrap_or(name),
196        }
197    }
198
199    pub fn max_token_count(&self) -> u64 {
200        match self {
201            Self::ThreePointFiveTurbo => 16_385,
202            Self::Four => 8_192,
203            Self::FourTurbo => 128_000,
204            Self::FourOmni => 128_000,
205            Self::FourOmniMini => 128_000,
206            Self::FourPointOne => 1_047_576,
207            Self::FourPointOneMini => 1_047_576,
208            Self::FourPointOneNano => 1_047_576,
209            Self::O1 => 200_000,
210            Self::O3Mini => 200_000,
211            Self::O3 => 200_000,
212            Self::O4Mini => 200_000,
213            Self::Five => 272_000,
214            Self::FiveCodex => 272_000,
215            Self::FiveMini => 272_000,
216            Self::FiveNano => 272_000,
217            Self::FivePointOne => 400_000,
218            Self::FivePointTwo => 400_000,
219            Self::FivePointTwoCodex => 400_000,
220            Self::Custom { max_tokens, .. } => *max_tokens,
221        }
222    }
223
224    pub fn max_output_tokens(&self) -> Option<u64> {
225        match self {
226            Self::Custom {
227                max_output_tokens, ..
228            } => *max_output_tokens,
229            Self::ThreePointFiveTurbo => Some(4_096),
230            Self::Four => Some(8_192),
231            Self::FourTurbo => Some(4_096),
232            Self::FourOmni => Some(16_384),
233            Self::FourOmniMini => Some(16_384),
234            Self::FourPointOne => Some(32_768),
235            Self::FourPointOneMini => Some(32_768),
236            Self::FourPointOneNano => Some(32_768),
237            Self::O1 => Some(100_000),
238            Self::O3Mini => Some(100_000),
239            Self::O3 => Some(100_000),
240            Self::O4Mini => Some(100_000),
241            Self::Five => Some(128_000),
242            Self::FiveCodex => Some(128_000),
243            Self::FiveMini => Some(128_000),
244            Self::FiveNano => Some(128_000),
245            Self::FivePointOne => Some(128_000),
246            Self::FivePointTwo => Some(128_000),
247            Self::FivePointTwoCodex => Some(128_000),
248        }
249    }
250
251    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
252        match self {
253            Self::Custom {
254                reasoning_effort, ..
255            } => reasoning_effort.to_owned(),
256            _ => None,
257        }
258    }
259
260    pub fn supports_chat_completions(&self) -> bool {
261        match self {
262            Self::Custom {
263                supports_chat_completions,
264                ..
265            } => *supports_chat_completions,
266            Self::FiveCodex | Self::FivePointTwoCodex => false,
267            _ => true,
268        }
269    }
270
271    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
272    ///
273    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
274    pub fn supports_parallel_tool_calls(&self) -> bool {
275        match self {
276            Self::ThreePointFiveTurbo
277            | Self::Four
278            | Self::FourTurbo
279            | Self::FourOmni
280            | Self::FourOmniMini
281            | Self::FourPointOne
282            | Self::FourPointOneMini
283            | Self::FourPointOneNano
284            | Self::Five
285            | Self::FiveCodex
286            | Self::FiveMini
287            | Self::FivePointOne
288            | Self::FivePointTwo
289            | Self::FivePointTwoCodex
290            | Self::FiveNano => true,
291            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
292        }
293    }
294
295    /// Returns whether the given model supports the `prompt_cache_key` parameter.
296    ///
297    /// If the model does not support the parameter, do not pass it up.
298    pub fn supports_prompt_cache_key(&self) -> bool {
299        true
300    }
301}
302
/// Request body for the chat completions endpoint.
///
/// Optional fields are omitted from the serialized JSON entirely
/// (rather than sent as `null`/`[]`) via `skip_serializing_if`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Wire id of the model to use (see `Model::id`).
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// When true, the server streams the response as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted from the payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
326
/// The `tool_choice` request field: one of the fixed modes ("auto",
/// "required", "none"), or a specific tool to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// Forces a specific tool; serialized untagged, i.e. as the tool
    /// definition object itself.
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool the model may call; tagged with `"type": "function"` on the wire.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// A callable function's name, optional description, and optional
/// parameter schema.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema describing the function's arguments, if any.
    pub parameters: Option<Value>,
}
350
/// A single chat message, tagged by its `role` field on the wire.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Assistant text; may be absent when the message only carries
        /// tool calls.
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// The result of a tool call, correlated by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
370
/// Message content: either a plain string or a list of typed parts
/// (text and images). Untagged, matching the API's two accepted
/// shapes for the `content` field.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
377
378impl MessageContent {
379    pub fn empty() -> Self {
380        MessageContent::Multipart(vec![])
381    }
382
383    pub fn push_part(&mut self, part: MessagePart) {
384        match self {
385            MessageContent::Plain(text) => {
386                *self =
387                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
388            }
389            MessageContent::Multipart(parts) if parts.is_empty() => match part {
390                MessagePart::Text { text } => *self = MessageContent::Plain(text),
391                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
392            },
393            MessageContent::Multipart(parts) => parts.push(part),
394        }
395    }
396}
397
398impl From<Vec<MessagePart>> for MessageContent {
399    fn from(mut parts: Vec<MessagePart>) -> Self {
400        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
401            MessageContent::Plain(std::mem::take(text))
402        } else {
403            MessageContent::Multipart(parts)
404        }
405    }
406}
407
/// One element of multipart message content, tagged by `type`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// An image reference in multipart content.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Image location — presumably a URL or data URI; not validated here.
    pub url: String,
    /// Optional detail-level hint; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
423
/// A tool call requested by the assistant, identified by `id`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened into the same JSON object as `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The payload of a tool call, tagged by `type`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A function invocation: the function name and its arguments as a raw
/// JSON string (not parsed by this crate).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
442
/// A complete (non-streaming) chat completions response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Creation time — presumably a Unix timestamp; not interpreted here.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One completion alternative within a [`Response`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
459
/// Incremental message content within a streamed choice delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted from serialized JSON when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A streamed fragment of a tool call; fields arrive incrementally and
/// chunks are correlated by `index`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// A streamed fragment of a function call; every field is optional
/// because the name and arguments arrive across multiple chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
484
/// Token accounting reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    /// Set on the final chunk for this choice.
    pub finish_reason: Option<String>,
}
498
/// Errors surfaced by [`stream_completion`].
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server responded with a non-success status. The body and
    /// headers are preserved so callers can inspect them.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, transport, I/O).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
511
/// An error object delivered inside the SSE stream body.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}

/// A single parsed SSE payload: either a normal event or an in-stream
/// error object. Untagged: serde tries the variants in declaration order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}

/// One streamed chat completions event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
529
/// POSTs `request` to `{api_url}/chat/completions` and returns a stream
/// of parsed server-sent events.
///
/// # Errors
///
/// Returns [`RequestError::HttpResponseError`] (with the full body and
/// headers) for a non-success status, and [`RequestError::Other`] for
/// serialization, request-building, or transport failures. Per-event
/// parse errors are yielded *inside* the stream as `Err` items.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        // Parse the body as an SSE stream, one `data:` payload per line.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Accept both "data: " and bare "data:" prefixes
                        // (servers differ on the space); any line without
                        // the prefix is dropped via the `?`.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: terminate the stream.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the unparseable payload for debugging,
                                    // then surface the parse error to the caller.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-success: drain the body so the error carries the server's message.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
599
/// OpenAI embedding model ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

/// Response body of the embeddings endpoint.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector, one per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
623
624pub fn embed<'a>(
625    client: &dyn HttpClient,
626    api_url: &str,
627    api_key: &str,
628    model: OpenAiEmbeddingModel,
629    texts: impl IntoIterator<Item = &'a str>,
630) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
631    let uri = format!("{api_url}/embeddings");
632
633    let request = OpenAiEmbeddingRequest {
634        model,
635        input: texts.into_iter().collect(),
636    };
637    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
638    let request = HttpRequest::builder()
639        .method(Method::POST)
640        .uri(uri)
641        .header("Content-Type", "application/json")
642        .header("Authorization", format!("Bearer {}", api_key.trim()))
643        .body(body)
644        .map(|request| client.send(request));
645
646    async move {
647        let mut response = request?.await?;
648        let mut body = String::new();
649        response.body_mut().read_to_string(&mut body).await?;
650
651        anyhow::ensure!(
652            response.status().is_success(),
653            "error during embedding, status: {:?}, body: {:?}",
654            response.status(),
655            body
656        );
657        let response: OpenAiEmbeddingResponse =
658            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
659        Ok(response)
660    }
661}