open_ai.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{
  4    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  5    http::{HeaderMap, HeaderValue},
  6};
  7use serde::{Deserialize, Serialize};
  8use serde_json::Value;
  9pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 10use std::{convert::TryFrom, future::Future};
 11use strum::EnumIter;
 12use thiserror::Error;
 13
 14pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 15
 16fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 17    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 18}
 19
 20#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 21#[serde(rename_all = "lowercase")]
 22pub enum Role {
 23    User,
 24    Assistant,
 25    System,
 26    Tool,
 27}
 28
 29impl TryFrom<String> for Role {
 30    type Error = anyhow::Error;
 31
 32    fn try_from(value: String) -> Result<Self> {
 33        match value.as_str() {
 34            "user" => Ok(Self::User),
 35            "assistant" => Ok(Self::Assistant),
 36            "system" => Ok(Self::System),
 37            "tool" => Ok(Self::Tool),
 38            _ => anyhow::bail!("invalid role '{value}'"),
 39        }
 40    }
 41}
 42
 43impl From<Role> for String {
 44    fn from(val: Role) -> Self {
 45        match val {
 46            Role::User => "user".to_owned(),
 47            Role::Assistant => "assistant".to_owned(),
 48            Role::System => "system".to_owned(),
 49            Role::Tool => "tool".to_owned(),
 50        }
 51    }
 52}
 53
 54#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 55#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 56pub enum Model {
 57    #[serde(rename = "gpt-3.5-turbo")]
 58    ThreePointFiveTurbo,
 59    #[serde(rename = "gpt-4")]
 60    Four,
 61    #[serde(rename = "gpt-4-turbo")]
 62    FourTurbo,
 63    #[serde(rename = "gpt-4o")]
 64    #[default]
 65    FourOmni,
 66    #[serde(rename = "gpt-4o-mini")]
 67    FourOmniMini,
 68    #[serde(rename = "gpt-4.1")]
 69    FourPointOne,
 70    #[serde(rename = "gpt-4.1-mini")]
 71    FourPointOneMini,
 72    #[serde(rename = "gpt-4.1-nano")]
 73    FourPointOneNano,
 74    #[serde(rename = "o1")]
 75    O1,
 76    #[serde(rename = "o3-mini")]
 77    O3Mini,
 78    #[serde(rename = "o3")]
 79    O3,
 80    #[serde(rename = "o4-mini")]
 81    O4Mini,
 82    #[serde(rename = "gpt-5")]
 83    Five,
 84    #[serde(rename = "gpt-5-mini")]
 85    FiveMini,
 86    #[serde(rename = "gpt-5-nano")]
 87    FiveNano,
 88    #[serde(rename = "gpt-5.1")]
 89    FivePointOne,
 90    #[serde(rename = "custom")]
 91    Custom {
 92        name: String,
 93        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 94        display_name: Option<String>,
 95        max_tokens: u64,
 96        max_output_tokens: Option<u64>,
 97        max_completion_tokens: Option<u64>,
 98        reasoning_effort: Option<ReasoningEffort>,
 99    },
100}
101
102impl Model {
103    pub fn default_fast() -> Self {
104        // TODO: Replace with FiveMini since all other models are deprecated
105        Self::FourPointOneMini
106    }
107
108    pub fn from_id(id: &str) -> Result<Self> {
109        match id {
110            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
111            "gpt-4" => Ok(Self::Four),
112            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
113            "gpt-4o" => Ok(Self::FourOmni),
114            "gpt-4o-mini" => Ok(Self::FourOmniMini),
115            "gpt-4.1" => Ok(Self::FourPointOne),
116            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
117            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
118            "o1" => Ok(Self::O1),
119            "o3-mini" => Ok(Self::O3Mini),
120            "o3" => Ok(Self::O3),
121            "o4-mini" => Ok(Self::O4Mini),
122            "gpt-5" => Ok(Self::Five),
123            "gpt-5-mini" => Ok(Self::FiveMini),
124            "gpt-5-nano" => Ok(Self::FiveNano),
125            "gpt-5.1" => Ok(Self::FivePointOne),
126            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
127        }
128    }
129
130    pub fn id(&self) -> &str {
131        match self {
132            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
133            Self::Four => "gpt-4",
134            Self::FourTurbo => "gpt-4-turbo",
135            Self::FourOmni => "gpt-4o",
136            Self::FourOmniMini => "gpt-4o-mini",
137            Self::FourPointOne => "gpt-4.1",
138            Self::FourPointOneMini => "gpt-4.1-mini",
139            Self::FourPointOneNano => "gpt-4.1-nano",
140            Self::O1 => "o1",
141            Self::O3Mini => "o3-mini",
142            Self::O3 => "o3",
143            Self::O4Mini => "o4-mini",
144            Self::Five => "gpt-5",
145            Self::FiveMini => "gpt-5-mini",
146            Self::FiveNano => "gpt-5-nano",
147            Self::FivePointOne => "gpt-5.1",
148            Self::Custom { name, .. } => name,
149        }
150    }
151
152    pub fn display_name(&self) -> &str {
153        match self {
154            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
155            Self::Four => "gpt-4",
156            Self::FourTurbo => "gpt-4-turbo",
157            Self::FourOmni => "gpt-4o",
158            Self::FourOmniMini => "gpt-4o-mini",
159            Self::FourPointOne => "gpt-4.1",
160            Self::FourPointOneMini => "gpt-4.1-mini",
161            Self::FourPointOneNano => "gpt-4.1-nano",
162            Self::O1 => "o1",
163            Self::O3Mini => "o3-mini",
164            Self::O3 => "o3",
165            Self::O4Mini => "o4-mini",
166            Self::Five => "gpt-5",
167            Self::FiveMini => "gpt-5-mini",
168            Self::FiveNano => "gpt-5-nano",
169            Self::FivePointOne => "gpt-5.1",
170            Self::Custom {
171                name, display_name, ..
172            } => display_name.as_ref().unwrap_or(name),
173        }
174    }
175
176    pub fn max_token_count(&self) -> u64 {
177        match self {
178            Self::ThreePointFiveTurbo => 16_385,
179            Self::Four => 8_192,
180            Self::FourTurbo => 128_000,
181            Self::FourOmni => 128_000,
182            Self::FourOmniMini => 128_000,
183            Self::FourPointOne => 1_047_576,
184            Self::FourPointOneMini => 1_047_576,
185            Self::FourPointOneNano => 1_047_576,
186            Self::O1 => 200_000,
187            Self::O3Mini => 200_000,
188            Self::O3 => 200_000,
189            Self::O4Mini => 200_000,
190            Self::Five => 272_000,
191            Self::FiveMini => 272_000,
192            Self::FiveNano => 272_000,
193            Self::FivePointOne => 400_000,
194            Self::Custom { max_tokens, .. } => *max_tokens,
195        }
196    }
197
198    pub fn max_output_tokens(&self) -> Option<u64> {
199        match self {
200            Self::Custom {
201                max_output_tokens, ..
202            } => *max_output_tokens,
203            Self::ThreePointFiveTurbo => Some(4_096),
204            Self::Four => Some(8_192),
205            Self::FourTurbo => Some(4_096),
206            Self::FourOmni => Some(16_384),
207            Self::FourOmniMini => Some(16_384),
208            Self::FourPointOne => Some(32_768),
209            Self::FourPointOneMini => Some(32_768),
210            Self::FourPointOneNano => Some(32_768),
211            Self::O1 => Some(100_000),
212            Self::O3Mini => Some(100_000),
213            Self::O3 => Some(100_000),
214            Self::O4Mini => Some(100_000),
215            Self::Five => Some(128_000),
216            Self::FiveMini => Some(128_000),
217            Self::FiveNano => Some(128_000),
218            Self::FivePointOne => Some(128_000),
219        }
220    }
221
222    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
223        match self {
224            Self::Custom {
225                reasoning_effort, ..
226            } => reasoning_effort.to_owned(),
227            _ => None,
228        }
229    }
230
231    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
232    ///
233    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
234    pub fn supports_parallel_tool_calls(&self) -> bool {
235        match self {
236            Self::ThreePointFiveTurbo
237            | Self::Four
238            | Self::FourTurbo
239            | Self::FourOmni
240            | Self::FourOmniMini
241            | Self::FourPointOne
242            | Self::FourPointOneMini
243            | Self::FourPointOneNano
244            | Self::Five
245            | Self::FiveMini
246            | Self::FivePointOne
247            | Self::FiveNano => true,
248            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
249        }
250    }
251
252    /// Returns whether the given model supports the `prompt_cache_key` parameter.
253    ///
254    /// If the model does not support the parameter, do not pass it up.
255    pub fn supports_prompt_cache_key(&self) -> bool {
256        true
257    }
258}
259
260#[derive(Debug, Serialize, Deserialize)]
261pub struct Request {
262    pub model: String,
263    pub messages: Vec<RequestMessage>,
264    pub stream: bool,
265    #[serde(default, skip_serializing_if = "Option::is_none")]
266    pub max_completion_tokens: Option<u64>,
267    #[serde(default, skip_serializing_if = "Vec::is_empty")]
268    pub stop: Vec<String>,
269    pub temperature: f32,
270    #[serde(default, skip_serializing_if = "Option::is_none")]
271    pub tool_choice: Option<ToolChoice>,
272    /// Whether to enable parallel function calling during tool use.
273    #[serde(default, skip_serializing_if = "Option::is_none")]
274    pub parallel_tool_calls: Option<bool>,
275    #[serde(default, skip_serializing_if = "Vec::is_empty")]
276    pub tools: Vec<ToolDefinition>,
277    #[serde(default, skip_serializing_if = "Option::is_none")]
278    pub prompt_cache_key: Option<String>,
279    #[serde(default, skip_serializing_if = "Option::is_none")]
280    pub reasoning_effort: Option<ReasoningEffort>,
281}
282
283#[derive(Debug, Serialize, Deserialize)]
284#[serde(rename_all = "lowercase")]
285pub enum ToolChoice {
286    Auto,
287    Required,
288    None,
289    #[serde(untagged)]
290    Other(ToolDefinition),
291}
292
293#[derive(Clone, Deserialize, Serialize, Debug)]
294#[serde(tag = "type", rename_all = "snake_case")]
295pub enum ToolDefinition {
296    #[allow(dead_code)]
297    Function { function: FunctionDefinition },
298}
299
300#[derive(Clone, Debug, Serialize, Deserialize)]
301pub struct FunctionDefinition {
302    pub name: String,
303    pub description: Option<String>,
304    pub parameters: Option<Value>,
305}
306
307#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
308#[serde(tag = "role", rename_all = "lowercase")]
309pub enum RequestMessage {
310    Assistant {
311        content: Option<MessageContent>,
312        #[serde(default, skip_serializing_if = "Vec::is_empty")]
313        tool_calls: Vec<ToolCall>,
314    },
315    User {
316        content: MessageContent,
317    },
318    System {
319        content: MessageContent,
320    },
321    Tool {
322        content: MessageContent,
323        tool_call_id: String,
324    },
325}
326
327#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
328#[serde(untagged)]
329pub enum MessageContent {
330    Plain(String),
331    Multipart(Vec<MessagePart>),
332}
333
334impl MessageContent {
335    pub fn empty() -> Self {
336        MessageContent::Multipart(vec![])
337    }
338
339    pub fn push_part(&mut self, part: MessagePart) {
340        match self {
341            MessageContent::Plain(text) => {
342                *self =
343                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
344            }
345            MessageContent::Multipart(parts) if parts.is_empty() => match part {
346                MessagePart::Text { text } => *self = MessageContent::Plain(text),
347                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
348            },
349            MessageContent::Multipart(parts) => parts.push(part),
350        }
351    }
352}
353
354impl From<Vec<MessagePart>> for MessageContent {
355    fn from(mut parts: Vec<MessagePart>) -> Self {
356        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
357            MessageContent::Plain(std::mem::take(text))
358        } else {
359            MessageContent::Multipart(parts)
360        }
361    }
362}
363
364#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
365#[serde(tag = "type")]
366pub enum MessagePart {
367    #[serde(rename = "text")]
368    Text { text: String },
369    #[serde(rename = "image_url")]
370    Image { image_url: ImageUrl },
371}
372
373#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
374pub struct ImageUrl {
375    pub url: String,
376    #[serde(skip_serializing_if = "Option::is_none")]
377    pub detail: Option<String>,
378}
379
380#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
381pub struct ToolCall {
382    pub id: String,
383    #[serde(flatten)]
384    pub content: ToolCallContent,
385}
386
387#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
388#[serde(tag = "type", rename_all = "lowercase")]
389pub enum ToolCallContent {
390    Function { function: FunctionContent },
391}
392
393#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
394pub struct FunctionContent {
395    pub name: String,
396    pub arguments: String,
397}
398
399#[derive(Clone, Serialize, Deserialize, Debug)]
400pub struct Response {
401    pub id: String,
402    pub object: String,
403    pub created: u64,
404    pub model: String,
405    pub choices: Vec<Choice>,
406    pub usage: Usage,
407}
408
409#[derive(Clone, Serialize, Deserialize, Debug)]
410pub struct Choice {
411    pub index: u32,
412    pub message: RequestMessage,
413    pub finish_reason: Option<String>,
414}
415
416#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
417pub struct ResponseMessageDelta {
418    pub role: Option<Role>,
419    pub content: Option<String>,
420    #[serde(default, skip_serializing_if = "is_none_or_empty")]
421    pub tool_calls: Option<Vec<ToolCallChunk>>,
422}
423
424#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
425pub struct ToolCallChunk {
426    pub index: usize,
427    pub id: Option<String>,
428
429    // There is also an optional `type` field that would determine if a
430    // function is there. Sometimes this streams in with the `function` before
431    // it streams in the `type`
432    pub function: Option<FunctionChunk>,
433}
434
435#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
436pub struct FunctionChunk {
437    pub name: Option<String>,
438    pub arguments: Option<String>,
439}
440
441#[derive(Clone, Serialize, Deserialize, Debug)]
442pub struct Usage {
443    pub prompt_tokens: u64,
444    pub completion_tokens: u64,
445    pub total_tokens: u64,
446}
447
448#[derive(Serialize, Deserialize, Debug)]
449pub struct ChoiceDelta {
450    pub index: u32,
451    pub delta: Option<ResponseMessageDelta>,
452    pub finish_reason: Option<String>,
453}
454
455#[derive(Error, Debug)]
456pub enum RequestError {
457    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
458    HttpResponseError {
459        provider: String,
460        status_code: StatusCode,
461        body: String,
462        headers: HeaderMap<HeaderValue>,
463    },
464    #[error(transparent)]
465    Other(#[from] anyhow::Error),
466}
467
468#[derive(Serialize, Deserialize, Debug)]
469pub struct ResponseStreamError {
470    message: String,
471}
472
473#[derive(Serialize, Deserialize, Debug)]
474#[serde(untagged)]
475pub enum ResponseStreamResult {
476    Ok(ResponseStreamEvent),
477    Err { error: ResponseStreamError },
478}
479
480#[derive(Serialize, Deserialize, Debug)]
481pub struct ResponseStreamEvent {
482    pub choices: Vec<ChoiceDelta>,
483    pub usage: Option<Usage>,
484}
485
486pub async fn stream_completion(
487    client: &dyn HttpClient,
488    provider_name: &str,
489    api_url: &str,
490    api_key: &str,
491    request: Request,
492) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
493    let uri = format!("{api_url}/chat/completions");
494    let request_builder = HttpRequest::builder()
495        .method(Method::POST)
496        .uri(uri)
497        .header("Content-Type", "application/json")
498        .header("Authorization", format!("Bearer {}", api_key.trim()));
499
500    let request = request_builder
501        .body(AsyncBody::from(
502            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
503        ))
504        .map_err(|e| RequestError::Other(e.into()))?;
505
506    let mut response = client.send(request).await?;
507    if response.status().is_success() {
508        let reader = BufReader::new(response.into_body());
509        Ok(reader
510            .lines()
511            .filter_map(|line| async move {
512                match line {
513                    Ok(line) => {
514                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
515                        if line == "[DONE]" {
516                            None
517                        } else {
518                            match serde_json::from_str(line) {
519                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
520                                Ok(ResponseStreamResult::Err { error }) => {
521                                    Some(Err(anyhow!(error.message)))
522                                }
523                                Err(error) => {
524                                    log::error!(
525                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
526                                        Response: `{}`",
527                                        error,
528                                        line,
529                                    );
530                                    Some(Err(anyhow!(error)))
531                                }
532                            }
533                        }
534                    }
535                    Err(error) => Some(Err(anyhow!(error))),
536                }
537            })
538            .boxed())
539    } else {
540        let mut body = String::new();
541        response
542            .body_mut()
543            .read_to_string(&mut body)
544            .await
545            .map_err(|e| RequestError::Other(e.into()))?;
546
547        Err(RequestError::HttpResponseError {
548            provider: provider_name.to_owned(),
549            status_code: response.status(),
550            body,
551            headers: response.headers().clone(),
552        })
553    }
554}
555
556#[derive(Copy, Clone, Serialize, Deserialize)]
557pub enum OpenAiEmbeddingModel {
558    #[serde(rename = "text-embedding-3-small")]
559    TextEmbedding3Small,
560    #[serde(rename = "text-embedding-3-large")]
561    TextEmbedding3Large,
562}
563
564#[derive(Serialize)]
565struct OpenAiEmbeddingRequest<'a> {
566    model: OpenAiEmbeddingModel,
567    input: Vec<&'a str>,
568}
569
570#[derive(Deserialize)]
571pub struct OpenAiEmbeddingResponse {
572    pub data: Vec<OpenAiEmbedding>,
573}
574
575#[derive(Deserialize)]
576pub struct OpenAiEmbedding {
577    pub embedding: Vec<f32>,
578}
579
580pub fn embed<'a>(
581    client: &dyn HttpClient,
582    api_url: &str,
583    api_key: &str,
584    model: OpenAiEmbeddingModel,
585    texts: impl IntoIterator<Item = &'a str>,
586) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
587    let uri = format!("{api_url}/embeddings");
588
589    let request = OpenAiEmbeddingRequest {
590        model,
591        input: texts.into_iter().collect(),
592    };
593    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
594    let request = HttpRequest::builder()
595        .method(Method::POST)
596        .uri(uri)
597        .header("Content-Type", "application/json")
598        .header("Authorization", format!("Bearer {}", api_key.trim()))
599        .body(body)
600        .map(|request| client.send(request));
601
602    async move {
603        let mut response = request?.await?;
604        let mut body = String::new();
605        response.body_mut().read_to_string(&mut body).await?;
606
607        anyhow::ensure!(
608            response.status().is_success(),
609            "error during embedding, status: {:?}, body: {:?}",
610            response.status(),
611            body
612        );
613        let response: OpenAiEmbeddingResponse =
614            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
615        Ok(response)
616    }
617}