open_ai.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6pub use settings::OpenAiReasoningEffort as ReasoningEffort;
  7use std::{convert::TryFrom, future::Future};
  8use strum::EnumIter;
  9
 10pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 11
 12fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 13    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 14}
 15
 16#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 17#[serde(rename_all = "lowercase")]
 18pub enum Role {
 19    User,
 20    Assistant,
 21    System,
 22    Tool,
 23}
 24
 25impl TryFrom<String> for Role {
 26    type Error = anyhow::Error;
 27
 28    fn try_from(value: String) -> Result<Self> {
 29        match value.as_str() {
 30            "user" => Ok(Self::User),
 31            "assistant" => Ok(Self::Assistant),
 32            "system" => Ok(Self::System),
 33            "tool" => Ok(Self::Tool),
 34            _ => anyhow::bail!("invalid role '{value}'"),
 35        }
 36    }
 37}
 38
 39impl From<Role> for String {
 40    fn from(val: Role) -> Self {
 41        match val {
 42            Role::User => "user".to_owned(),
 43            Role::Assistant => "assistant".to_owned(),
 44            Role::System => "system".to_owned(),
 45            Role::Tool => "tool".to_owned(),
 46        }
 47    }
 48}
 49
 50#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 51#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 52pub enum Model {
 53    #[serde(rename = "gpt-3.5-turbo")]
 54    ThreePointFiveTurbo,
 55    #[serde(rename = "gpt-4")]
 56    Four,
 57    #[serde(rename = "gpt-4-turbo")]
 58    FourTurbo,
 59    #[serde(rename = "gpt-4o")]
 60    #[default]
 61    FourOmni,
 62    #[serde(rename = "gpt-4o-mini")]
 63    FourOmniMini,
 64    #[serde(rename = "gpt-4.1")]
 65    FourPointOne,
 66    #[serde(rename = "gpt-4.1-mini")]
 67    FourPointOneMini,
 68    #[serde(rename = "gpt-4.1-nano")]
 69    FourPointOneNano,
 70    #[serde(rename = "o1")]
 71    O1,
 72    #[serde(rename = "o3-mini")]
 73    O3Mini,
 74    #[serde(rename = "o3")]
 75    O3,
 76    #[serde(rename = "o4-mini")]
 77    O4Mini,
 78    #[serde(rename = "gpt-5")]
 79    Five,
 80    #[serde(rename = "gpt-5-mini")]
 81    FiveMini,
 82    #[serde(rename = "gpt-5-nano")]
 83    FiveNano,
 84
 85    #[serde(rename = "custom")]
 86    Custom {
 87        name: String,
 88        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 89        display_name: Option<String>,
 90        max_tokens: u64,
 91        max_output_tokens: Option<u64>,
 92        max_completion_tokens: Option<u64>,
 93        reasoning_effort: Option<ReasoningEffort>,
 94    },
 95}
 96
 97impl Model {
 98    pub const fn default_fast() -> Self {
 99        // TODO: Replace with FiveMini since all other models are deprecated
100        Self::FourPointOneMini
101    }
102
103    pub fn from_id(id: &str) -> Result<Self> {
104        match id {
105            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
106            "gpt-4" => Ok(Self::Four),
107            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
108            "gpt-4o" => Ok(Self::FourOmni),
109            "gpt-4o-mini" => Ok(Self::FourOmniMini),
110            "gpt-4.1" => Ok(Self::FourPointOne),
111            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
112            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
113            "o1" => Ok(Self::O1),
114            "o3-mini" => Ok(Self::O3Mini),
115            "o3" => Ok(Self::O3),
116            "o4-mini" => Ok(Self::O4Mini),
117            "gpt-5" => Ok(Self::Five),
118            "gpt-5-mini" => Ok(Self::FiveMini),
119            "gpt-5-nano" => Ok(Self::FiveNano),
120            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
121        }
122    }
123
124    pub fn id(&self) -> &str {
125        match self {
126            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
127            Self::Four => "gpt-4",
128            Self::FourTurbo => "gpt-4-turbo",
129            Self::FourOmni => "gpt-4o",
130            Self::FourOmniMini => "gpt-4o-mini",
131            Self::FourPointOne => "gpt-4.1",
132            Self::FourPointOneMini => "gpt-4.1-mini",
133            Self::FourPointOneNano => "gpt-4.1-nano",
134            Self::O1 => "o1",
135            Self::O3Mini => "o3-mini",
136            Self::O3 => "o3",
137            Self::O4Mini => "o4-mini",
138            Self::Five => "gpt-5",
139            Self::FiveMini => "gpt-5-mini",
140            Self::FiveNano => "gpt-5-nano",
141            Self::Custom { name, .. } => name,
142        }
143    }
144
145    pub fn display_name(&self) -> &str {
146        match self {
147            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
148            Self::Four => "gpt-4",
149            Self::FourTurbo => "gpt-4-turbo",
150            Self::FourOmni => "gpt-4o",
151            Self::FourOmniMini => "gpt-4o-mini",
152            Self::FourPointOne => "gpt-4.1",
153            Self::FourPointOneMini => "gpt-4.1-mini",
154            Self::FourPointOneNano => "gpt-4.1-nano",
155            Self::O1 => "o1",
156            Self::O3Mini => "o3-mini",
157            Self::O3 => "o3",
158            Self::O4Mini => "o4-mini",
159            Self::Five => "gpt-5",
160            Self::FiveMini => "gpt-5-mini",
161            Self::FiveNano => "gpt-5-nano",
162            Self::Custom {
163                name, display_name, ..
164            } => display_name.as_ref().unwrap_or(name),
165        }
166    }
167
168    pub const fn max_token_count(&self) -> u64 {
169        match self {
170            Self::ThreePointFiveTurbo => 16_385,
171            Self::Four => 8_192,
172            Self::FourTurbo => 128_000,
173            Self::FourOmni => 128_000,
174            Self::FourOmniMini => 128_000,
175            Self::FourPointOne => 1_047_576,
176            Self::FourPointOneMini => 1_047_576,
177            Self::FourPointOneNano => 1_047_576,
178            Self::O1 => 200_000,
179            Self::O3Mini => 200_000,
180            Self::O3 => 200_000,
181            Self::O4Mini => 200_000,
182            Self::Five => 272_000,
183            Self::FiveMini => 272_000,
184            Self::FiveNano => 272_000,
185            Self::Custom { max_tokens, .. } => *max_tokens,
186        }
187    }
188
189    pub const fn max_output_tokens(&self) -> Option<u64> {
190        match self {
191            Self::Custom {
192                max_output_tokens, ..
193            } => *max_output_tokens,
194            Self::ThreePointFiveTurbo => Some(4_096),
195            Self::Four => Some(8_192),
196            Self::FourTurbo => Some(4_096),
197            Self::FourOmni => Some(16_384),
198            Self::FourOmniMini => Some(16_384),
199            Self::FourPointOne => Some(32_768),
200            Self::FourPointOneMini => Some(32_768),
201            Self::FourPointOneNano => Some(32_768),
202            Self::O1 => Some(100_000),
203            Self::O3Mini => Some(100_000),
204            Self::O3 => Some(100_000),
205            Self::O4Mini => Some(100_000),
206            Self::Five => Some(128_000),
207            Self::FiveMini => Some(128_000),
208            Self::FiveNano => Some(128_000),
209        }
210    }
211
212    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
213        match self {
214            Self::Custom {
215                reasoning_effort, ..
216            } => reasoning_effort.to_owned(),
217            _ => None,
218        }
219    }
220
221    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
222    ///
223    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
224    pub const fn supports_parallel_tool_calls(&self) -> bool {
225        match self {
226            Self::ThreePointFiveTurbo
227            | Self::Four
228            | Self::FourTurbo
229            | Self::FourOmni
230            | Self::FourOmniMini
231            | Self::FourPointOne
232            | Self::FourPointOneMini
233            | Self::FourPointOneNano
234            | Self::Five
235            | Self::FiveMini
236            | Self::FiveNano => true,
237            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
238        }
239    }
240
241    /// Returns whether the given model supports the `prompt_cache_key` parameter.
242    ///
243    /// If the model does not support the parameter, do not pass it up.
244    pub const fn supports_prompt_cache_key(&self) -> bool {
245        true
246    }
247}
248
249#[derive(Debug, Serialize, Deserialize)]
250pub struct Request {
251    pub model: String,
252    pub messages: Vec<RequestMessage>,
253    pub stream: bool,
254    #[serde(default, skip_serializing_if = "Option::is_none")]
255    pub max_completion_tokens: Option<u64>,
256    #[serde(default, skip_serializing_if = "Vec::is_empty")]
257    pub stop: Vec<String>,
258    pub temperature: f32,
259    #[serde(default, skip_serializing_if = "Option::is_none")]
260    pub tool_choice: Option<ToolChoice>,
261    /// Whether to enable parallel function calling during tool use.
262    #[serde(default, skip_serializing_if = "Option::is_none")]
263    pub parallel_tool_calls: Option<bool>,
264    #[serde(default, skip_serializing_if = "Vec::is_empty")]
265    pub tools: Vec<ToolDefinition>,
266    #[serde(default, skip_serializing_if = "Option::is_none")]
267    pub prompt_cache_key: Option<String>,
268    #[serde(default, skip_serializing_if = "Option::is_none")]
269    pub reasoning_effort: Option<ReasoningEffort>,
270}
271
272#[derive(Debug, Serialize, Deserialize)]
273#[serde(rename_all = "lowercase")]
274pub enum ToolChoice {
275    Auto,
276    Required,
277    None,
278    #[serde(untagged)]
279    Other(ToolDefinition),
280}
281
282#[derive(Clone, Deserialize, Serialize, Debug)]
283#[serde(tag = "type", rename_all = "snake_case")]
284pub enum ToolDefinition {
285    #[allow(dead_code)]
286    Function { function: FunctionDefinition },
287}
288
289#[derive(Clone, Debug, Serialize, Deserialize)]
290pub struct FunctionDefinition {
291    pub name: String,
292    pub description: Option<String>,
293    pub parameters: Option<Value>,
294}
295
296#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
297#[serde(tag = "role", rename_all = "lowercase")]
298pub enum RequestMessage {
299    Assistant {
300        content: Option<MessageContent>,
301        #[serde(default, skip_serializing_if = "Vec::is_empty")]
302        tool_calls: Vec<ToolCall>,
303    },
304    User {
305        content: MessageContent,
306    },
307    System {
308        content: MessageContent,
309    },
310    Tool {
311        content: MessageContent,
312        tool_call_id: String,
313    },
314}
315
316#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
317#[serde(untagged)]
318pub enum MessageContent {
319    Plain(String),
320    Multipart(Vec<MessagePart>),
321}
322
323impl MessageContent {
324    pub const fn empty() -> Self {
325        MessageContent::Multipart(vec![])
326    }
327
328    pub fn push_part(&mut self, part: MessagePart) {
329        match self {
330            MessageContent::Plain(text) => {
331                *self =
332                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
333            }
334            MessageContent::Multipart(parts) if parts.is_empty() => match part {
335                MessagePart::Text { text } => *self = MessageContent::Plain(text),
336                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
337            },
338            MessageContent::Multipart(parts) => parts.push(part),
339        }
340    }
341}
342
343impl From<Vec<MessagePart>> for MessageContent {
344    fn from(mut parts: Vec<MessagePart>) -> Self {
345        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
346            MessageContent::Plain(std::mem::take(text))
347        } else {
348            MessageContent::Multipart(parts)
349        }
350    }
351}
352
353#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
354#[serde(tag = "type")]
355pub enum MessagePart {
356    #[serde(rename = "text")]
357    Text { text: String },
358    #[serde(rename = "image_url")]
359    Image { image_url: ImageUrl },
360}
361
362#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
363pub struct ImageUrl {
364    pub url: String,
365    #[serde(skip_serializing_if = "Option::is_none")]
366    pub detail: Option<String>,
367}
368
369#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
370pub struct ToolCall {
371    pub id: String,
372    #[serde(flatten)]
373    pub content: ToolCallContent,
374}
375
376#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
377#[serde(tag = "type", rename_all = "lowercase")]
378pub enum ToolCallContent {
379    Function { function: FunctionContent },
380}
381
382#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
383pub struct FunctionContent {
384    pub name: String,
385    pub arguments: String,
386}
387
388#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
389pub struct ResponseMessageDelta {
390    pub role: Option<Role>,
391    pub content: Option<String>,
392    #[serde(default, skip_serializing_if = "is_none_or_empty")]
393    pub tool_calls: Option<Vec<ToolCallChunk>>,
394}
395
396#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
397pub struct ToolCallChunk {
398    pub index: usize,
399    pub id: Option<String>,
400
401    // There is also an optional `type` field that would determine if a
402    // function is there. Sometimes this streams in with the `function` before
403    // it streams in the `type`
404    pub function: Option<FunctionChunk>,
405}
406
407#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
408pub struct FunctionChunk {
409    pub name: Option<String>,
410    pub arguments: Option<String>,
411}
412
413#[derive(Serialize, Deserialize, Debug)]
414pub struct Usage {
415    pub prompt_tokens: u64,
416    pub completion_tokens: u64,
417    pub total_tokens: u64,
418}
419
420#[derive(Serialize, Deserialize, Debug)]
421pub struct ChoiceDelta {
422    pub index: u32,
423    pub delta: ResponseMessageDelta,
424    pub finish_reason: Option<String>,
425}
426
427#[derive(Serialize, Deserialize, Debug)]
428pub struct OpenAiError {
429    message: String,
430}
431
432#[derive(Serialize, Deserialize, Debug)]
433#[serde(untagged)]
434pub enum ResponseStreamResult {
435    Ok(ResponseStreamEvent),
436    Err { error: OpenAiError },
437}
438
439#[derive(Serialize, Deserialize, Debug)]
440pub struct ResponseStreamEvent {
441    pub choices: Vec<ChoiceDelta>,
442    pub usage: Option<Usage>,
443}
444
445pub async fn stream_completion(
446    client: &dyn HttpClient,
447    api_url: &str,
448    api_key: &str,
449    request: Request,
450) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
451    let uri = format!("{api_url}/chat/completions");
452    let request_builder = HttpRequest::builder()
453        .method(Method::POST)
454        .uri(uri)
455        .header("Content-Type", "application/json")
456        .header("Authorization", format!("Bearer {}", api_key.trim()));
457
458    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
459    let mut response = client.send(request).await?;
460    if response.status().is_success() {
461        let reader = BufReader::new(response.into_body());
462        Ok(reader
463            .lines()
464            .filter_map(|line| async move {
465                match line {
466                    Ok(line) => {
467                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
468                        if line == "[DONE]" {
469                            None
470                        } else {
471                            match serde_json::from_str(line) {
472                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
473                                Ok(ResponseStreamResult::Err { error }) => {
474                                    Some(Err(anyhow!(error.message)))
475                                }
476                                Err(error) => {
477                                    log::error!(
478                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
479                                        Response: `{}`",
480                                        error,
481                                        line,
482                                    );
483                                    Some(Err(anyhow!(error)))
484                                }
485                            }
486                        }
487                    }
488                    Err(error) => Some(Err(anyhow!(error))),
489                }
490            })
491            .boxed())
492    } else {
493        let mut body = String::new();
494        response.body_mut().read_to_string(&mut body).await?;
495
496        #[derive(Deserialize)]
497        struct OpenAiResponse {
498            error: OpenAiError,
499        }
500
501        match serde_json::from_str::<OpenAiResponse>(&body) {
502            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
503                "API request to {} failed: {}",
504                api_url,
505                response.error.message,
506            )),
507
508            _ => anyhow::bail!(
509                "API request to {} failed with status {}: {}",
510                api_url,
511                response.status(),
512                body,
513            ),
514        }
515    }
516}
517
518#[derive(Copy, Clone, Serialize, Deserialize)]
519pub enum OpenAiEmbeddingModel {
520    #[serde(rename = "text-embedding-3-small")]
521    TextEmbedding3Small,
522    #[serde(rename = "text-embedding-3-large")]
523    TextEmbedding3Large,
524}
525
526#[derive(Serialize)]
527struct OpenAiEmbeddingRequest<'a> {
528    model: OpenAiEmbeddingModel,
529    input: Vec<&'a str>,
530}
531
532#[derive(Deserialize)]
533pub struct OpenAiEmbeddingResponse {
534    pub data: Vec<OpenAiEmbedding>,
535}
536
537#[derive(Deserialize)]
538pub struct OpenAiEmbedding {
539    pub embedding: Vec<f32>,
540}
541
542pub fn embed<'a>(
543    client: &dyn HttpClient,
544    api_url: &str,
545    api_key: &str,
546    model: OpenAiEmbeddingModel,
547    texts: impl IntoIterator<Item = &'a str>,
548) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
549    let uri = format!("{api_url}/embeddings");
550
551    let request = OpenAiEmbeddingRequest {
552        model,
553        input: texts.into_iter().collect(),
554    };
555    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
556    let request = HttpRequest::builder()
557        .method(Method::POST)
558        .uri(uri)
559        .header("Content-Type", "application/json")
560        .header("Authorization", format!("Bearer {}", api_key.trim()))
561        .body(body)
562        .map(|request| client.send(request));
563
564    async move {
565        let mut response = request?.await?;
566        let mut body = String::new();
567        response.body_mut().read_to_string(&mut body).await?;
568
569        anyhow::ensure!(
570            response.status().is_success(),
571            "error during embedding, status: {:?}, body: {:?}",
572            response.status(),
573            body
574        );
575        let response: OpenAiEmbeddingResponse =
576            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
577        Ok(response)
578    }
579}