open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{
  5    AsyncBufReadExt, AsyncReadExt, StreamExt,
  6    io::BufReader,
  7    stream::{self, BoxStream},
  8};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12use std::{
 13    convert::TryFrom,
 14    future::{self, Future},
 15};
 16use strum::EnumIter;
 17
 18pub use supported_countries::*;
 19
 20pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 21
 22fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 23    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 24}
 25
/// The author of a chat message, serialized in OpenAI's lowercase wire
/// format (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 34
 35impl TryFrom<String> for Role {
 36    type Error = anyhow::Error;
 37
 38    fn try_from(value: String) -> Result<Self> {
 39        match value.as_str() {
 40            "user" => Ok(Self::User),
 41            "assistant" => Ok(Self::Assistant),
 42            "system" => Ok(Self::System),
 43            "tool" => Ok(Self::Tool),
 44            _ => Err(anyhow!("invalid role '{value}'")),
 45        }
 46    }
 47}
 48
 49impl From<Role> for String {
 50    fn from(val: Role) -> Self {
 51        match val {
 52            Role::User => "user".to_owned(),
 53            Role::Assistant => "assistant".to_owned(),
 54            Role::System => "system".to_owned(),
 55            Role::Tool => "tool".to_owned(),
 56        }
 57    }
 58}
 59
/// The OpenAI chat models known to this crate, plus `Custom` for arbitrary
/// model names (e.g. OpenAI-compatible third-party endpoints).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,

    /// A model outside the list above, identified by its raw API name and
    /// caller-supplied token limits.
    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}
 93
 94impl Model {
 95    pub fn from_id(id: &str) -> Result<Self> {
 96        match id {
 97            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 98            "gpt-4" => Ok(Self::Four),
 99            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
100            "gpt-4o" => Ok(Self::FourOmni),
101            "gpt-4o-mini" => Ok(Self::FourOmniMini),
102            "o1" => Ok(Self::O1),
103            "o1-preview" => Ok(Self::O1Preview),
104            "o1-mini" => Ok(Self::O1Mini),
105            "o3-mini" => Ok(Self::O3Mini),
106            _ => Err(anyhow!("invalid model id")),
107        }
108    }
109
110    pub fn id(&self) -> &str {
111        match self {
112            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
113            Self::Four => "gpt-4",
114            Self::FourTurbo => "gpt-4-turbo",
115            Self::FourOmni => "gpt-4o",
116            Self::FourOmniMini => "gpt-4o-mini",
117            Self::O1 => "o1",
118            Self::O1Preview => "o1-preview",
119            Self::O1Mini => "o1-mini",
120            Self::O3Mini => "o3-mini",
121            Self::Custom { name, .. } => name,
122        }
123    }
124
125    pub fn display_name(&self) -> &str {
126        match self {
127            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
128            Self::Four => "gpt-4",
129            Self::FourTurbo => "gpt-4-turbo",
130            Self::FourOmni => "gpt-4o",
131            Self::FourOmniMini => "gpt-4o-mini",
132            Self::O1 => "o1",
133            Self::O1Preview => "o1-preview",
134            Self::O1Mini => "o1-mini",
135            Self::O3Mini => "o3-mini",
136            Self::Custom {
137                name, display_name, ..
138            } => display_name.as_ref().unwrap_or(name),
139        }
140    }
141
142    pub fn max_token_count(&self) -> usize {
143        match self {
144            Self::ThreePointFiveTurbo => 16_385,
145            Self::Four => 8_192,
146            Self::FourTurbo => 128_000,
147            Self::FourOmni => 128_000,
148            Self::FourOmniMini => 128_000,
149            Self::O1 => 200_000,
150            Self::O1Preview => 128_000,
151            Self::O1Mini => 128_000,
152            Self::O3Mini => 200_000,
153            Self::Custom { max_tokens, .. } => *max_tokens,
154        }
155    }
156
157    pub fn max_output_tokens(&self) -> Option<u32> {
158        match self {
159            Self::Custom {
160                max_output_tokens, ..
161            } => *max_output_tokens,
162            _ => None,
163        }
164    }
165
166    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
167    ///
168    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
169    pub fn supports_parallel_tool_calls(&self) -> bool {
170        match self {
171            Self::ThreePointFiveTurbo
172            | Self::Four
173            | Self::FourTurbo
174            | Self::FourOmni
175            | Self::FourOmniMini
176            | Self::O1
177            | Self::O1Preview
178            | Self::O1Mini => true,
179            _ => false,
180        }
181    }
182}
183
/// Body of a `POST {api_url}/chat/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// When `true`, the API streams the response as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
202
/// Body of a `POST {api_url}/completions` (legacy text-completion) request.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    /// Predicted-output content, used to speed up responses whose bulk is
    /// known ahead of time.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    // NOTE(review): not a documented OpenAI parameter — presumably consumed
    // by a compatible backend; verify against the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
214
/// Predicted-output payload, serialized as `{"type": "content", "content": …}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}
220
221#[derive(Debug, Serialize, Deserialize)]
222#[serde(untagged)]
223pub enum ToolChoice {
224    Auto,
225    Required,
226    None,
227    Other(ToolDefinition),
228}
229
/// A tool exposed to the model; currently only function tools, serialized as
/// `{"type": "function", "function": …}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
236
/// Declaration of a callable function: name, optional description, and an
/// optional JSON-schema `parameters` object.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}
243
/// A chat message, internally tagged by its lowercase `role` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// May be absent when the assistant turn consists only of tool calls.
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        /// The tool's output, paired with the id of the call it answers.
        content: String,
        tool_call_id: String,
    },
}
263
/// A completed tool call made by the assistant; the `content` fields are
/// flattened alongside `id` in the serialized JSON.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
270
/// The payload of a tool call, tagged by `type`; currently only function calls.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
276
/// A function invocation: the function name plus its arguments as a raw
/// JSON-encoded string (not parsed here).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
282
/// Incremental message payload within one streamed chat-completion chunk.
/// All fields are optional because each chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    // Omitted from serialization when `None` or empty (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
290
/// A fragment of a tool call within a streamed chunk; `index` identifies
/// which in-progress call the fragment belongs to.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    /// Present only on the first fragment of a given call.
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
301
/// Partial function-call data: the name arrives once, and `arguments` streams
/// in as string fragments to be concatenated by the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
307
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
314
/// One choice within a streamed chunk: its index, the delta to apply, and an
/// optional finish reason on the final chunk of that choice.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
321
/// A single SSE payload: either a normal event or an in-band error object.
/// `untagged` means serde tries `Ok` first and falls back to `Err`.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
328
/// One event of a streaming chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    // NOTE(review): `Response::created` is u64 but this is u32;
    // `adapt_response_to_stream` narrows with `as u32` — confirm intended.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Populated only on the final usage-bearing event, if requested.
    pub usage: Option<Usage>,
}
336
/// Response body of the legacy `/completions` endpoint.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
346
/// One generated text alternative in a legacy completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}
351
/// Response body of a non-streaming `/chat/completions` request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
361
/// One choice of a non-streaming chat response. The message reuses
/// `RequestMessage` since it shares the same role-tagged shape.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
368
369pub async fn complete(
370    client: &dyn HttpClient,
371    api_url: &str,
372    api_key: &str,
373    request: Request,
374) -> Result<Response> {
375    let uri = format!("{api_url}/chat/completions");
376    let request_builder = HttpRequest::builder()
377        .method(Method::POST)
378        .uri(uri)
379        .header("Content-Type", "application/json")
380        .header("Authorization", format!("Bearer {}", api_key));
381
382    let mut request_body = request;
383    request_body.stream = false;
384
385    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
386    let mut response = client.send(request).await?;
387
388    if response.status().is_success() {
389        let mut body = String::new();
390        response.body_mut().read_to_string(&mut body).await?;
391        let response: Response = serde_json::from_str(&body)?;
392        Ok(response)
393    } else {
394        let mut body = String::new();
395        response.body_mut().read_to_string(&mut body).await?;
396
397        #[derive(Deserialize)]
398        struct OpenAiResponse {
399            error: OpenAiError,
400        }
401
402        #[derive(Deserialize)]
403        struct OpenAiError {
404            message: String,
405        }
406
407        match serde_json::from_str::<OpenAiResponse>(&body) {
408            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
409                "Failed to connect to OpenAI API: {}",
410                response.error.message,
411            )),
412
413            _ => Err(anyhow!(
414                "Failed to connect to OpenAI API: {} {}",
415                response.status(),
416                body,
417            )),
418        }
419    }
420}
421
422pub async fn complete_text(
423    client: &dyn HttpClient,
424    api_url: &str,
425    api_key: &str,
426    request: CompletionRequest,
427) -> Result<CompletionResponse> {
428    let uri = format!("{api_url}/completions");
429    let request_builder = HttpRequest::builder()
430        .method(Method::POST)
431        .uri(uri)
432        .header("Content-Type", "application/json")
433        .header("Authorization", format!("Bearer {}", api_key));
434
435    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
436    let mut response = client.send(request).await?;
437
438    if response.status().is_success() {
439        let mut body = String::new();
440        response.body_mut().read_to_string(&mut body).await?;
441        let response = serde_json::from_str(&body)?;
442        Ok(response)
443    } else {
444        let mut body = String::new();
445        response.body_mut().read_to_string(&mut body).await?;
446
447        #[derive(Deserialize)]
448        struct OpenAiResponse {
449            error: OpenAiError,
450        }
451
452        #[derive(Deserialize)]
453        struct OpenAiError {
454            message: String,
455        }
456
457        match serde_json::from_str::<OpenAiResponse>(&body) {
458            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
459                "Failed to connect to OpenAI API: {}",
460                response.error.message,
461            )),
462
463            _ => Err(anyhow!(
464                "Failed to connect to OpenAI API: {} {}",
465                response.status(),
466                body,
467            )),
468        }
469    }
470}
471
472fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
473    ResponseStreamEvent {
474        created: response.created as u32,
475        model: response.model,
476        choices: response
477            .choices
478            .into_iter()
479            .map(|choice| ChoiceDelta {
480                index: choice.index,
481                delta: ResponseMessageDelta {
482                    role: Some(match choice.message {
483                        RequestMessage::Assistant { .. } => Role::Assistant,
484                        RequestMessage::User { .. } => Role::User,
485                        RequestMessage::System { .. } => Role::System,
486                        RequestMessage::Tool { .. } => Role::Tool,
487                    }),
488                    content: match choice.message {
489                        RequestMessage::Assistant { content, .. } => content,
490                        RequestMessage::User { content } => Some(content),
491                        RequestMessage::System { content } => Some(content),
492                        RequestMessage::Tool { content, .. } => Some(content),
493                    },
494                    tool_calls: None,
495                },
496                finish_reason: choice.finish_reason,
497            })
498            .collect(),
499        usage: Some(response.usage),
500    }
501}
502
/// Starts a streaming chat completion and yields parsed server-sent events.
///
/// Model ids beginning with "o1" are routed through the non-streaming
/// `complete` call and wrapped in a one-item stream — presumably because
/// those models did not support streaming; confirm before extending.
/// Note the caller is responsible for `request.stream` being `true` here.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model.starts_with("o1") {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines are prefixed "data: "; anything else
                        // (blank keep-alives, comments) is dropped here.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // End-of-stream sentinel sent by the API.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        // Prefer the API's own error message; fall back to status + raw body.
        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
576
/// Embedding models supported by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
584
/// Request body for `/embeddings`; borrows the input texts to avoid copies.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
590
/// Response body for `/embeddings`: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
595
/// A single embedding vector returned by the API.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
600
/// Requests embeddings for `texts` using the given model.
///
/// The payload and HTTP request are built eagerly, before the future is
/// returned; only send/receive happens inside the `async move` block. That is
/// what lets the future be `'static` even though `texts` is borrowed.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // unwrap: serializing this plain struct of strings is not expected to
    // fail — presumably infallible; confirm if payloads ever grow richer.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    // Builder errors are deferred into the future via `map`, so this function
    // itself never fails.
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}