open_ai.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, future::Future};
  7use strum::EnumIter;
  8
/// Base URL of the hosted OpenAI API.
///
/// The request helpers in this module take an `api_url` parameter and append endpoint paths
/// (`/chat/completions`, `/embeddings`) to it.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 10
 11fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 12    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 13}
 14
/// Role attached to a chat message.
///
/// Serialized in lowercase: `"user"`, `"assistant"`, `"system"`, `"tool"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 23
 24impl TryFrom<String> for Role {
 25    type Error = anyhow::Error;
 26
 27    fn try_from(value: String) -> Result<Self> {
 28        match value.as_str() {
 29            "user" => Ok(Self::User),
 30            "assistant" => Ok(Self::Assistant),
 31            "system" => Ok(Self::System),
 32            "tool" => Ok(Self::Tool),
 33            _ => anyhow::bail!("invalid role '{value}'"),
 34        }
 35    }
 36}
 37
 38impl From<Role> for String {
 39    fn from(val: Role) -> Self {
 40        match val {
 41            Role::User => "user".to_owned(),
 42            Role::Assistant => "assistant".to_owned(),
 43            Role::System => "system".to_owned(),
 44            Role::Tool => "tool".to_owned(),
 45        }
 46    }
 47}
 48
/// A chat model available through the OpenAI API.
///
/// Built-in variants (de)serialize as their OpenAI model id (e.g. `"gpt-4o"`) via the serde
/// renames below. `Custom` describes a model that is not one of the built-in variants.
/// `FourOmni` (`gpt-4o`) is the `Default`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1")]
    FourPointOne,
    #[serde(rename = "gpt-4.1-mini")]
    FourPointOneMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,

    /// A model that is not one of the built-in variants, described by its id and token limits.
    #[serde(rename = "custom")]
    Custom {
        /// Model id sent to the API; returned by `Model::id`.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Token limit reported by `Model::max_token_count`.
        max_tokens: u64,
        /// Output-token limit reported by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        // NOTE(review): not read by any method in this file — confirm how callers use it and
        // how it relates to `max_output_tokens`.
        max_completion_tokens: Option<u64>,
    },
}
 88
 89impl Model {
 90    pub fn default_fast() -> Self {
 91        Self::FourPointOneMini
 92    }
 93
 94    pub fn from_id(id: &str) -> Result<Self> {
 95        match id {
 96            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 97            "gpt-4" => Ok(Self::Four),
 98            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 99            "gpt-4o" => Ok(Self::FourOmni),
100            "gpt-4o-mini" => Ok(Self::FourOmniMini),
101            "gpt-4.1" => Ok(Self::FourPointOne),
102            "gpt-4.1-mini" => Ok(Self::FourPointOneMini),
103            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
104            "o1" => Ok(Self::O1),
105            "o3-mini" => Ok(Self::O3Mini),
106            "o3" => Ok(Self::O3),
107            "o4-mini" => Ok(Self::O4Mini),
108            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
109        }
110    }
111
112    pub fn id(&self) -> &str {
113        match self {
114            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
115            Self::Four => "gpt-4",
116            Self::FourTurbo => "gpt-4-turbo",
117            Self::FourOmni => "gpt-4o",
118            Self::FourOmniMini => "gpt-4o-mini",
119            Self::FourPointOne => "gpt-4.1",
120            Self::FourPointOneMini => "gpt-4.1-mini",
121            Self::FourPointOneNano => "gpt-4.1-nano",
122            Self::O1 => "o1",
123            Self::O3Mini => "o3-mini",
124            Self::O3 => "o3",
125            Self::O4Mini => "o4-mini",
126            Self::Custom { name, .. } => name,
127        }
128    }
129
130    pub fn display_name(&self) -> &str {
131        match self {
132            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
133            Self::Four => "gpt-4",
134            Self::FourTurbo => "gpt-4-turbo",
135            Self::FourOmni => "gpt-4o",
136            Self::FourOmniMini => "gpt-4o-mini",
137            Self::FourPointOne => "gpt-4.1",
138            Self::FourPointOneMini => "gpt-4.1-mini",
139            Self::FourPointOneNano => "gpt-4.1-nano",
140            Self::O1 => "o1",
141            Self::O3Mini => "o3-mini",
142            Self::O3 => "o3",
143            Self::O4Mini => "o4-mini",
144            Self::Custom {
145                name, display_name, ..
146            } => display_name.as_ref().unwrap_or(name),
147        }
148    }
149
150    pub fn max_token_count(&self) -> u64 {
151        match self {
152            Self::ThreePointFiveTurbo => 16_385,
153            Self::Four => 8_192,
154            Self::FourTurbo => 128_000,
155            Self::FourOmni => 128_000,
156            Self::FourOmniMini => 128_000,
157            Self::FourPointOne => 1_047_576,
158            Self::FourPointOneMini => 1_047_576,
159            Self::FourPointOneNano => 1_047_576,
160            Self::O1 => 200_000,
161            Self::O3Mini => 200_000,
162            Self::O3 => 200_000,
163            Self::O4Mini => 200_000,
164            Self::Custom { max_tokens, .. } => *max_tokens,
165        }
166    }
167
168    pub fn max_output_tokens(&self) -> Option<u64> {
169        match self {
170            Self::Custom {
171                max_output_tokens, ..
172            } => *max_output_tokens,
173            Self::ThreePointFiveTurbo => Some(4_096),
174            Self::Four => Some(8_192),
175            Self::FourTurbo => Some(4_096),
176            Self::FourOmni => Some(16_384),
177            Self::FourOmniMini => Some(16_384),
178            Self::FourPointOne => Some(32_768),
179            Self::FourPointOneMini => Some(32_768),
180            Self::FourPointOneNano => Some(32_768),
181            Self::O1 => Some(100_000),
182            Self::O3Mini => Some(100_000),
183            Self::O3 => Some(100_000),
184            Self::O4Mini => Some(100_000),
185        }
186    }
187
188    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
189    ///
190    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
191    pub fn supports_parallel_tool_calls(&self) -> bool {
192        match self {
193            Self::ThreePointFiveTurbo
194            | Self::Four
195            | Self::FourTurbo
196            | Self::FourOmni
197            | Self::FourOmniMini
198            | Self::FourPointOne
199            | Self::FourPointOneMini
200            | Self::FourPointOneNano => true,
201            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
202        }
203    }
204}
205
/// Body of a Chat Completions request, serialized as JSON by `stream_completion`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, such as the value returned by `Model::id`.
    pub model: String,
    /// The conversation sent to the model.
    pub messages: Vec<RequestMessage>,
    /// Whether a streamed response is requested. `stream_completion` parses the response as a
    /// `data:`-line event stream, so it presumably expects this to be `true` — confirm.
    pub stream: bool,
    /// Optional completion-token limit; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Tool-selection policy; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    ///
    /// Omitted when `None`. See `Model::supports_parallel_tool_calls` for models that reject it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
224
225#[derive(Debug, Serialize, Deserialize)]
226#[serde(untagged)]
227pub enum ToolChoice {
228    Auto,
229    Required,
230    None,
231    Other(ToolDefinition),
232}
233
/// A tool the model may call, serialized with a `"type"` tag (only `"function"` exists here).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): this `allow` has no stated reason. `dead_code` does not normally fire on a
    // variant of a publicly reachable enum, so it may be removable — confirm before deleting.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
240
/// Definition of a callable function tool.
///
/// `description` and `parameters` have no `skip_serializing_if`, so they serialize as `null`
/// when `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model uses to refer to this tool.
    pub name: String,
    /// Optional description of the tool.
    pub description: Option<String>,
    /// Arbitrary JSON describing the parameters (presumably a JSON Schema object — confirm).
    pub parameters: Option<Value>,
}
247
/// A message in a request's conversation.
///
/// Internally tagged by `"role"`, whose value is the lowercase variant name (`"assistant"`,
/// `"user"`, `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A prior assistant turn. `content` may be absent; `tool_calls` is omitted when empty.
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// The result of a tool call, linked to the call by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
267
/// Message content: a plain string or a list of typed parts.
///
/// Untagged, so `Plain` serializes as a bare JSON string and `Multipart` as a JSON array.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
274
275impl MessageContent {
276    pub fn empty() -> Self {
277        MessageContent::Multipart(vec![])
278    }
279
280    pub fn push_part(&mut self, part: MessagePart) {
281        match self {
282            MessageContent::Plain(text) => {
283                *self =
284                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
285            }
286            MessageContent::Multipart(parts) if parts.is_empty() => match part {
287                MessagePart::Text { text } => *self = MessageContent::Plain(text),
288                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
289            },
290            MessageContent::Multipart(parts) => parts.push(part),
291        }
292    }
293}
294
295impl From<Vec<MessagePart>> for MessageContent {
296    fn from(mut parts: Vec<MessagePart>) -> Self {
297        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
298            MessageContent::Plain(std::mem::take(text))
299        } else {
300            MessageContent::Multipart(parts)
301        }
302    }
303}
304
/// One part of `Multipart` message content, internally tagged by `"type"`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    /// Serialized as `{"type": "text", "text": ...}`.
    #[serde(rename = "text")]
    Text { text: String },
    /// Serialized as `{"type": "image_url", "image_url": {...}}`.
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
313
/// Image reference carried by `MessagePart::Image`.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Image location (presumably an http(s) or data URL — confirm with callers).
    pub url: String,
    /// Optional detail hint; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
320
/// A tool call made by the assistant, as sent back in a request.
///
/// `content` is flattened, so its `"type"` and `"function"` keys appear next to `"id"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
327
/// Body of a [`ToolCall`], tagged by `"type"` (`"function"` is the only variant).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
333
/// Name and arguments of a function tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Raw argument string. It is not parsed here; presumably JSON-encoded — confirm.
    pub arguments: String,
}
339
/// Incremental piece of an assistant message carried by one streamed [`ChoiceDelta`].
///
/// Every field is optional because a chunk may carry only some of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Tool-call fragments; omitted when serializing if `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
347
/// A streamed fragment of a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of the tool call this fragment belongs to (presumably used to merge fragments
    /// across chunks — confirm with the consumer).
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
358
/// Streamed fragment of a function call's name and/or arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    /// Partial argument text (presumably concatenated across chunks by the consumer — confirm).
    pub arguments: Option<String>,
}
364
/// Token usage reported in a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
371
/// One choice within a streamed [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Why generation stopped, if reported (presumably `None` on intermediate chunks — confirm).
    pub finish_reason: Option<String>,
}
378
/// One decoded `data:` payload from the completion stream: either a regular event or an
/// error payload of the form `{"error": "<message>"}`.
///
/// Untagged: serde tries `Ok` first and falls back to `Err`.
// NOTE(review): `error` is typed as a string. If a server sends `error` as an object, neither
// variant matches and decoding fails with a generic serde error, so the server's message is
// lost — confirm against real error payloads.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
385
/// A single event yielded by the stream returned from `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
392
393pub async fn stream_completion(
394    client: &dyn HttpClient,
395    api_url: &str,
396    api_key: &str,
397    request: Request,
398) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
399    let uri = format!("{api_url}/chat/completions");
400    let request_builder = HttpRequest::builder()
401        .method(Method::POST)
402        .uri(uri)
403        .header("Content-Type", "application/json")
404        .header("Authorization", format!("Bearer {}", api_key));
405
406    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
407    let mut response = client.send(request).await?;
408    if response.status().is_success() {
409        let reader = BufReader::new(response.into_body());
410        Ok(reader
411            .lines()
412            .filter_map(|line| async move {
413                match line {
414                    Ok(line) => {
415                        let line = line.strip_prefix("data: ")?;
416                        if line == "[DONE]" {
417                            None
418                        } else {
419                            match serde_json::from_str(line) {
420                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
421                                Ok(ResponseStreamResult::Err { error }) => {
422                                    Some(Err(anyhow!(error)))
423                                }
424                                Err(error) => Some(Err(anyhow!(error))),
425                            }
426                        }
427                    }
428                    Err(error) => Some(Err(anyhow!(error))),
429                }
430            })
431            .boxed())
432    } else {
433        let mut body = String::new();
434        response.body_mut().read_to_string(&mut body).await?;
435
436        #[derive(Deserialize)]
437        struct OpenAiResponse {
438            error: OpenAiError,
439        }
440
441        #[derive(Deserialize)]
442        struct OpenAiError {
443            message: String,
444        }
445
446        match serde_json::from_str::<OpenAiResponse>(&body) {
447            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
448                "API request to {} failed: {}",
449                api_url,
450                response.error.message,
451            )),
452
453            _ => anyhow::bail!(
454                "API request to {} failed with status {}: {}",
455                api_url,
456                response.status(),
457                body,
458            ),
459        }
460    }
461}
462
/// Embedding model, serialized as its OpenAI model id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
470
/// JSON body sent by `embed` to `{api_url}/embeddings`.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    /// Texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
476
/// Parsed response from the embeddings endpoint.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// One entry per input text (presumably in input order — confirm against the API docs).
    pub data: Vec<OpenAiEmbedding>,
}
481
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
486
487pub fn embed<'a>(
488    client: &dyn HttpClient,
489    api_url: &str,
490    api_key: &str,
491    model: OpenAiEmbeddingModel,
492    texts: impl IntoIterator<Item = &'a str>,
493) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
494    let uri = format!("{api_url}/embeddings");
495
496    let request = OpenAiEmbeddingRequest {
497        model,
498        input: texts.into_iter().collect(),
499    };
500    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
501    let request = HttpRequest::builder()
502        .method(Method::POST)
503        .uri(uri)
504        .header("Content-Type", "application/json")
505        .header("Authorization", format!("Bearer {}", api_key))
506        .body(body)
507        .map(|request| client.send(request));
508
509    async move {
510        let mut response = request?.await?;
511        let mut body = String::new();
512        response.body_mut().read_to_string(&mut body).await?;
513
514        anyhow::ensure!(
515            response.status().is_success(),
516            "error during embedding, status: {:?}, body: {:?}",
517            response.status(),
518            body
519        );
520        let response: OpenAiEmbeddingResponse =
521            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
522        Ok(response)
523    }
524}