open_ai.rs

mod supported_countries;

use anyhow::{anyhow, Context, Result};
use futures::{
    io::BufReader,
    stream::{self, BoxStream},
    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    convert::TryFrom,
    future::{self, Future},
    pin::Pin,
};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

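/// Message author role as represented by the OpenAI chat completions API.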
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

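/// Chat models known to this client, plus a `Custom` variant for arbitrary
/// OpenAI-compatible model names supplied in configuration.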
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}

impl Model {
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            "gpt-4-turbo" | "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            "gpt-4o" => Ok(Self::FourOmni),
            "gpt-4o-mini" => Ok(Self::FourOmniMini),
            "o1-preview" => Ok(Self::O1Preview),
            "o1-mini" => Ok(Self::O1Mini),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::Custom {
                name, display_name, ..
            } => display_name.as_ref().unwrap_or(name),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::ThreePointFiveTurbo => 16385,
            Self::Four => 8192,
            Self::FourTurbo => 128000,
            Self::FourOmni => 128000,
            Self::FourOmniMini => 128000,
            Self::O1Preview => 128000,
            Self::O1Mini => 128000,
            Self::Custom { max_tokens, .. } => *max_tokens,
        }
    }

    pub fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Self::Custom {
                max_output_tokens, ..
            } => *max_output_tokens,
            _ => None,
        }
    }
}

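/// Request body for the `/chat/completions` endpoint.
///
/// A minimal construction sketch; the values below are illustrative, not API
/// defaults:
///
/// ```ignore
/// let request = Request {
///     model: Model::FourOmni.id().into(),
///     messages: vec![RequestMessage::User {
///         content: "Say hello.".into(),
///     }],
///     stream: true,
///     max_tokens: None,
///     stop: Vec::new(),
///     temperature: 1.0,
///     tool_choice: None,
///     tools: Vec::new(),
/// };
/// ```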
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}

// `tool_choice` is serialized as the strings "auto", "required", or "none",
// or as a tool definition object. An enum-wide `untagged` representation
// would serialize the unit variants as `null`, so only `Other` is untagged
// (this requires serde 1.0.181 or newer).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // The API also sends an optional `type` field that indicates whether the
    // chunk carries a function call. Because `type` can arrive in a later
    // chunk than `function`, it is not modeled here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}

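/// Sends a non-streaming request to `{api_url}/chat/completions` and parses the
/// complete response body. `request.stream` is forced to `false` before sending.
///
/// A usage sketch, assuming `client`, `api_key`, and `request` are already in
/// scope (placeholder names, not part of this crate):
///
/// ```ignore
/// let response = complete(client.as_ref(), OPEN_AI_API_URL, &api_key, request).await?;
/// let first_choice = response.choices.first();
/// ```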
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<Response> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let mut request_body = request;
    request_body.stream = false;

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: Response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

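/// Sends a legacy text-completion request to `{api_url}/completions`. Unlike
/// `complete`, the request body is serialized as-is.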
pub async fn complete_text(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: CompletionRequest,
) -> Result<CompletionResponse> {
    let uri = format!("{api_url}/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

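/// Adapts a non-streaming `Response` into a single `ResponseStreamEvent` so that
/// `stream_completion` can present models without streaming support (currently
/// the o1 family) behind the same streaming interface.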
fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
    ResponseStreamEvent {
        created: response.created as u32,
        model: response.model,
        choices: response
            .choices
            .into_iter()
            .map(|choice| ChoiceDelta {
                index: choice.index,
                delta: ResponseMessageDelta {
                    role: Some(match choice.message {
                        RequestMessage::Assistant { .. } => Role::Assistant,
                        RequestMessage::User { .. } => Role::User,
                        RequestMessage::System { .. } => Role::System,
                        RequestMessage::Tool { .. } => Role::Tool,
                    }),
                    content: match choice.message {
                        RequestMessage::Assistant { content, .. } => content,
                        RequestMessage::User { content } => Some(content),
                        RequestMessage::System { content } => Some(content),
                        RequestMessage::Tool { content, .. } => Some(content),
                    },
                    tool_calls: None,
                },
                finish_reason: choice.finish_reason,
            })
            .collect(),
        usage: Some(response.usage),
    }
}

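/// Sends a streaming request to `{api_url}/chat/completions` and yields one
/// `ResponseStreamEvent` per server-sent `data:` line. Requests for `o1-preview`
/// and `o1-mini` are sent without streaming and adapted into a single-event
/// stream.
///
/// A consumption sketch, assuming `client`, `api_key`, and `request` are in
/// scope (placeholder names):
///
/// ```ignore
/// let mut events = stream_completion(client.as_ref(), OPEN_AI_API_URL, &api_key, request).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(text) = choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```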
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

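/// Requests embeddings for a batch of texts from `{api_url}/embeddings`.
///
/// A usage sketch, assuming `client` and `api_key` are in scope (placeholder
/// names):
///
/// ```ignore
/// let response = embed(
///     client.as_ref(),
///     OPEN_AI_API_URL,
///     &api_key,
///     OpenAiEmbeddingModel::TextEmbedding3Small,
///     ["first chunk", "second chunk"],
/// )
/// .await?;
/// let vectors: Vec<Vec<f32>> = response.data.into_iter().map(|d| d.embedding).collect();
/// ```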
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

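/// Filters a completion event stream down to the argument fragments of the first
/// tool call whose function name matches `tool_name`, returning an error if the
/// stream ends without that tool being called. The yielded strings are partial
/// JSON fragments that the caller is expected to concatenate and parse.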
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}

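/// Extracts the text content deltas from a completion event stream, skipping
/// events that carry no content.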
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}