// open_ai.rs

  1pub mod responses;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  5use http_client::{
  6    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  7    http::{HeaderMap, HeaderValue},
  8};
  9use serde::{Deserialize, Serialize};
 10use serde_json::Value;
 11pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 12use std::{convert::TryFrom, future::Future};
 13use strum::EnumIter;
 14use thiserror::Error;
 15
 16pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 17
 18fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 19    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 20}
 21
/// The author role of a chat message, serialized in lowercase wire format
/// (`user`, `assistant`, `system`, `tool`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 30
 31impl TryFrom<String> for Role {
 32    type Error = anyhow::Error;
 33
 34    fn try_from(value: String) -> Result<Self> {
 35        match value.as_str() {
 36            "user" => Ok(Self::User),
 37            "assistant" => Ok(Self::Assistant),
 38            "system" => Ok(Self::System),
 39            "tool" => Ok(Self::Tool),
 40            _ => anyhow::bail!("invalid role '{value}'"),
 41        }
 42    }
 43}
 44
 45impl From<Role> for String {
 46    fn from(val: Role) -> Self {
 47        match val {
 48            Role::User => "user".to_owned(),
 49            Role::Assistant => "assistant".to_owned(),
 50            Role::System => "system".to_owned(),
 51            Role::Tool => "tool".to_owned(),
 52        }
 53    }
 54}
 55
/// The OpenAI models this client knows about, plus a fully user-configurable
/// `Custom` variant for unlisted or proxied models.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    #[serde(rename = "gpt-5-mini")]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "custom")]
    Custom {
        // Model id sent to the API in the request's `model` field.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Context window size, in tokens.
        max_tokens: u64,
        // Optional cap on output tokens for this model.
        max_output_tokens: Option<u64>,
        // Optional `max_completion_tokens` value for this model.
        max_completion_tokens: Option<u64>,
        // Optional reasoning effort setting for this model.
        reasoning_effort: Option<ReasoningEffort>,
        // Whether the model supports the Chat Completions API; defaults to
        // true via `default_supports_chat_completions` when absent.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}
109
/// Serde default for `Model::Custom::supports_chat_completions`: assume a
/// custom model supports the Chat Completions API unless configured otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
113
114impl Model {
115    pub fn default_fast() -> Self {
116        // TODO: Replace with FiveMini since all other models are deprecated
117        Self::FourPointOneMini
118    }
119
120    pub fn from_id(id: &str) -> Result<Self> {
121        match id {
122            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
123            "gpt-4" => Ok(Self::Four),
124            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
125            "gpt-4o" => Ok(Self::FourOmni),
126            "gpt-4o-mini" => Ok(Self::FourOmniMini),
127            "gpt-4.1" => Ok(Self::FourPointOne),
128            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
129            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
130            "o1" => Ok(Self::O1),
131            "o3-mini" => Ok(Self::O3Mini),
132            "o3" => Ok(Self::O3),
133            "o4-mini" => Ok(Self::O4Mini),
134            "gpt-5" => Ok(Self::Five),
135            "gpt-5-codex" => Ok(Self::FiveCodex),
136            "gpt-5-mini" => Ok(Self::FiveMini),
137            "gpt-5-nano" => Ok(Self::FiveNano),
138            "gpt-5.1" => Ok(Self::FivePointOne),
139            "gpt-5.2" => Ok(Self::FivePointTwo),
140            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
141        }
142    }
143
144    pub fn id(&self) -> &str {
145        match self {
146            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
147            Self::Four => "gpt-4",
148            Self::FourTurbo => "gpt-4-turbo",
149            Self::FourOmni => "gpt-4o",
150            Self::FourOmniMini => "gpt-4o-mini",
151            Self::FourPointOne => "gpt-4.1",
152            Self::FourPointOneMini => "gpt-4.1-mini",
153            Self::FourPointOneNano => "gpt-4.1-nano",
154            Self::O1 => "o1",
155            Self::O3Mini => "o3-mini",
156            Self::O3 => "o3",
157            Self::O4Mini => "o4-mini",
158            Self::Five => "gpt-5",
159            Self::FiveCodex => "gpt-5-codex",
160            Self::FiveMini => "gpt-5-mini",
161            Self::FiveNano => "gpt-5-nano",
162            Self::FivePointOne => "gpt-5.1",
163            Self::FivePointTwo => "gpt-5.2",
164            Self::Custom { name, .. } => name,
165        }
166    }
167
168    pub fn display_name(&self) -> &str {
169        match self {
170            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
171            Self::Four => "gpt-4",
172            Self::FourTurbo => "gpt-4-turbo",
173            Self::FourOmni => "gpt-4o",
174            Self::FourOmniMini => "gpt-4o-mini",
175            Self::FourPointOne => "gpt-4.1",
176            Self::FourPointOneMini => "gpt-4.1-mini",
177            Self::FourPointOneNano => "gpt-4.1-nano",
178            Self::O1 => "o1",
179            Self::O3Mini => "o3-mini",
180            Self::O3 => "o3",
181            Self::O4Mini => "o4-mini",
182            Self::Five => "gpt-5",
183            Self::FiveCodex => "gpt-5-codex",
184            Self::FiveMini => "gpt-5-mini",
185            Self::FiveNano => "gpt-5-nano",
186            Self::FivePointOne => "gpt-5.1",
187            Self::FivePointTwo => "gpt-5.2",
188            Self::Custom {
189                name, display_name, ..
190            } => display_name.as_ref().unwrap_or(name),
191        }
192    }
193
194    pub fn max_token_count(&self) -> u64 {
195        match self {
196            Self::ThreePointFiveTurbo => 16_385,
197            Self::Four => 8_192,
198            Self::FourTurbo => 128_000,
199            Self::FourOmni => 128_000,
200            Self::FourOmniMini => 128_000,
201            Self::FourPointOne => 1_047_576,
202            Self::FourPointOneMini => 1_047_576,
203            Self::FourPointOneNano => 1_047_576,
204            Self::O1 => 200_000,
205            Self::O3Mini => 200_000,
206            Self::O3 => 200_000,
207            Self::O4Mini => 200_000,
208            Self::Five => 272_000,
209            Self::FiveCodex => 272_000,
210            Self::FiveMini => 272_000,
211            Self::FiveNano => 272_000,
212            Self::FivePointOne => 400_000,
213            Self::FivePointTwo => 400_000,
214            Self::Custom { max_tokens, .. } => *max_tokens,
215        }
216    }
217
218    pub fn max_output_tokens(&self) -> Option<u64> {
219        match self {
220            Self::Custom {
221                max_output_tokens, ..
222            } => *max_output_tokens,
223            Self::ThreePointFiveTurbo => Some(4_096),
224            Self::Four => Some(8_192),
225            Self::FourTurbo => Some(4_096),
226            Self::FourOmni => Some(16_384),
227            Self::FourOmniMini => Some(16_384),
228            Self::FourPointOne => Some(32_768),
229            Self::FourPointOneMini => Some(32_768),
230            Self::FourPointOneNano => Some(32_768),
231            Self::O1 => Some(100_000),
232            Self::O3Mini => Some(100_000),
233            Self::O3 => Some(100_000),
234            Self::O4Mini => Some(100_000),
235            Self::Five => Some(128_000),
236            Self::FiveCodex => Some(128_000),
237            Self::FiveMini => Some(128_000),
238            Self::FiveNano => Some(128_000),
239            Self::FivePointOne => Some(128_000),
240            Self::FivePointTwo => Some(128_000),
241        }
242    }
243
244    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
245        match self {
246            Self::Custom {
247                reasoning_effort, ..
248            } => reasoning_effort.to_owned(),
249            _ => None,
250        }
251    }
252
253    pub fn supports_chat_completions(&self) -> bool {
254        match self {
255            Self::Custom {
256                supports_chat_completions,
257                ..
258            } => *supports_chat_completions,
259            Self::FiveCodex => false,
260            _ => true,
261        }
262    }
263
264    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
265    ///
266    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
267    pub fn supports_parallel_tool_calls(&self) -> bool {
268        match self {
269            Self::ThreePointFiveTurbo
270            | Self::Four
271            | Self::FourTurbo
272            | Self::FourOmni
273            | Self::FourOmniMini
274            | Self::FourPointOne
275            | Self::FourPointOneMini
276            | Self::FourPointOneNano
277            | Self::Five
278            | Self::FiveCodex
279            | Self::FiveMini
280            | Self::FivePointOne
281            | Self::FivePointTwo
282            | Self::FiveNano => true,
283            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
284        }
285    }
286
287    /// Returns whether the given model supports the `prompt_cache_key` parameter.
288    ///
289    /// If the model does not support the parameter, do not pass it up.
290    pub fn supports_prompt_cache_key(&self) -> bool {
291        true
292    }
293}
294
/// Request body for the Chat Completions endpoint. Optional/empty fields are
/// omitted from the serialized JSON so unsupported parameters are never sent.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    // Model id string, as returned by `Model::id`.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // When true, the API responds with server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
318
/// The `tool_choice` request field: one of the lowercase keywords
/// (`"auto"`, `"required"`, `"none"`) or a specific tool to force.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    // Untagged: serializes as the tool definition object itself.
    #[serde(untagged)]
    Other(ToolDefinition),
}
328
/// A tool the model may call, tagged by `type`. Only `function` tools exist here.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
335
/// Schema for a callable function tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON schema describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
342
/// A single chat message in a request, tagged by its `role` field.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<MessageContent>,
        // Tool calls the assistant issued in this turn; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        // Id of the tool call this message responds to.
        tool_call_id: String,
    },
}
362
/// Message content: either a plain string or a list of typed parts
/// (text/image). Untagged so it serializes in whichever shape it holds.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
369
370impl MessageContent {
371    pub fn empty() -> Self {
372        MessageContent::Multipart(vec![])
373    }
374
375    pub fn push_part(&mut self, part: MessagePart) {
376        match self {
377            MessageContent::Plain(text) => {
378                *self =
379                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
380            }
381            MessageContent::Multipart(parts) if parts.is_empty() => match part {
382                MessagePart::Text { text } => *self = MessageContent::Plain(text),
383                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
384            },
385            MessageContent::Multipart(parts) => parts.push(part),
386        }
387    }
388}
389
390impl From<Vec<MessagePart>> for MessageContent {
391    fn from(mut parts: Vec<MessagePart>) -> Self {
392        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
393            MessageContent::Plain(std::mem::take(text))
394        } else {
395            MessageContent::Multipart(parts)
396        }
397    }
398}
399
/// One piece of multipart message content, tagged by `type`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
408
/// An image reference inside a message part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    // Optional detail level; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
415
/// A tool call issued by the assistant. The typed content is flattened so the
/// JSON carries `id` alongside the content's own fields.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
422
/// The payload of a tool call, tagged by `type`; only functions exist here.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
428
/// A function invocation: its name and the arguments as a JSON-encoded string.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Raw JSON string, as delivered by the API (not parsed here).
    pub arguments: String,
}
434
/// A complete (non-streaming) chat-completions response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
444
/// One completion choice from a non-streaming response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
451
/// Incremental message content from one streaming chunk. Fields only carry
/// what changed in this chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    // Omitted from serialization when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
459
/// A streamed fragment of a tool call; fragments with the same `index`
/// belong to the same call and must be accumulated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
470
/// A streamed fragment of a function call; `arguments` arrives in pieces
/// that the consumer concatenates.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
476
/// Token accounting reported by the API for a completion.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
483
/// One choice's delta within a streaming event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    pub finish_reason: Option<String>,
}
490
/// Errors produced by [`stream_completion`]: either a non-success HTTP
/// response (with body and headers preserved for diagnostics) or any other
/// underlying failure.
#[derive(Error, Debug)]
pub enum RequestError {
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
503
/// Error object the API can emit inside the event stream body.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}
508
/// A single deserialized stream line: either a normal event or an in-band
/// error object. Untagged so serde tries each shape in order.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}
515
/// One parsed streaming event: per-choice deltas plus, on the final event,
/// usage totals.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
521
/// Starts a streaming request against `{api_url}/chat/completions` and
/// returns a stream of parsed events.
///
/// The response body is read line-by-line as server-sent events: `data:`
/// payloads are deserialized into [`ResponseStreamEvent`]s, in-band error
/// objects become `Err` items, and the `[DONE]` sentinel ends the stream.
///
/// # Errors
/// Returns [`RequestError::HttpResponseError`] for non-2xx responses (with
/// the body and headers captured), or [`RequestError::Other`] for
/// serialization/transport failures.
pub async fn stream_completion(
    client: &dyn HttpClient,
    provider_name: &str,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        // Trim the key defensively so stray whitespace from config doesn't
        // corrupt the Authorization header.
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()));

    let request = request_builder
        .body(AsyncBody::from(
            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
        ))
        .map_err(|e| RequestError::Other(e.into()))?;

    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Accept both "data: {...}" and "data:{...}"; lines
                        // without the SSE data prefix are dropped.
                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel: yield nothing further.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error.message)))
                                }
                                Err(error) => {
                                    // Log the unparseable payload for debugging,
                                    // then surface the parse error to the caller.
                                    log::error!(
                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
                                        Response: `{}`",
                                        error,
                                        line,
                                    );
                                    Some(Err(anyhow!(error)))
                                }
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-2xx: drain the body so it can be reported with status/headers.
        let mut body = String::new();
        response
            .body_mut()
            .read_to_string(&mut body)
            .await
            .map_err(|e| RequestError::Other(e.into()))?;

        Err(RequestError::HttpResponseError {
            provider: provider_name.to_owned(),
            status_code: response.status(),
            body,
            headers: response.headers().clone(),
        })
    }
}
591
/// The OpenAI embedding models supported by [`embed`].
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
599
/// Request body for the embeddings endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
605
/// Response body from the embeddings endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
610
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
615
/// Requests embeddings for `texts` from `{api_url}/embeddings`.
///
/// Builds the request eagerly and returns a `'static` future that sends it
/// and parses the response, so the borrowed `texts` need not outlive the
/// returned future.
///
/// # Errors
/// The future resolves to an error on transport failure, a non-success HTTP
/// status (the body is included in the message), or an unparseable response.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serializing this plain struct cannot fail, hence the unwrap.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key.trim()))
        .body(body)
        // Kick off the send here so the future below owns no borrows.
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        anyhow::ensure!(
            response.status().is_success(),
            "error during embedding, status: {:?}, body: {:?}",
            response.status(),
            body
        );
        let response: OpenAiEmbeddingResponse =
            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
        Ok(response)
    }
}