open_ai.rs

mod supported_countries;

use anyhow::{anyhow, Context, Result};
use futures::{
    io::BufReader,
    stream::{self, BoxStream},
    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    convert::TryFrom,
    future::{self, Future},
    pin::Pin,
};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

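/// The OpenAI chat models this module knows about.
///
/// Variants are (de)serialized under their API identifier (for example
/// `"gpt-4o"`), and models not covered by a named variant can be described
/// with `Model::Custom`. A minimal sketch; the limits below are illustrative,
/// not authoritative:
///
/// ```ignore
/// let model = Model::Custom {
///     name: "my-proxy-model".into(),
///     display_name: Some("My Proxy Model".into()),
///     max_tokens: 128000,
///     max_output_tokens: Some(4096),
///     max_completion_tokens: None,
/// };
/// assert_eq!(model.id(), "my-proxy-model");
/// ```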
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}

impl Model {
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            "gpt-4-turbo" | "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            "gpt-4o" => Ok(Self::FourOmni),
            "gpt-4o-mini" => Ok(Self::FourOmniMini),
            "o1" => Ok(Self::O1),
            "o1-preview" => Ok(Self::O1Preview),
            "o1-mini" => Ok(Self::O1Mini),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1 => "o1",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::Custom { name, .. } => name,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1 => "o1",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::Custom {
                name, display_name, ..
            } => display_name.as_ref().unwrap_or(name),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::ThreePointFiveTurbo => 16385,
            Self::Four => 8192,
            Self::FourTurbo => 128000,
            Self::FourOmni => 128000,
            Self::FourOmniMini => 128000,
            Self::O1 => 200000,
            Self::O1Preview => 128000,
            Self::O1Mini => 128000,
            Self::Custom { max_tokens, .. } => *max_tokens,
        }
    }

    pub fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Self::Custom {
                max_output_tokens, ..
            } => *max_output_tokens,
            _ => None,
        }
    }
}

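/// The request body for `POST {api_url}/chat/completions`.
///
/// A rough sketch of how a request might be assembled; the field values here
/// are illustrative only:
///
/// ```ignore
/// let request = Request {
///     model: Model::FourOmni.id().to_string(),
///     messages: vec![
///         RequestMessage::System {
///             content: "You are a helpful assistant.".into(),
///         },
///         RequestMessage::User {
///             content: "Hello!".into(),
///         },
///     ],
///     stream: true,
///     max_tokens: None,
///     stop: Vec::new(),
///     temperature: 0.7,
///     tool_choice: None,
///     tools: Vec::new(),
/// };
/// ```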
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

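/// A single chat message, tagged by `role` when serialized, for example
/// `{"role": "user", "content": "..."}` or
/// `{"role": "tool", "content": "...", "tool_call_id": "..."}`.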
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // The chunk may also carry an optional `type` field indicating a function
    // call, but the `function` payload sometimes arrives in a chunk before the
    // `type` does, so we don't rely on it here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}

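/// Sends a non-streaming chat completion request and parses the response.
/// `request.stream` is forced to `false` before the request is sent.
///
/// A usage sketch, assuming a `client` of type `&dyn HttpClient`, an
/// `api_key` string, and a prepared `request` are in scope:
///
/// ```ignore
/// let response = complete(client, OPEN_AI_API_URL, &api_key, request).await?;
/// if let Some(choice) = response.choices.first() {
///     // Inspect `choice.message` and `choice.finish_reason` here.
/// }
/// ```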
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<Response> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let mut request_body = request;
    request_body.stream = false;

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: Response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

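/// Sends a request to the legacy text completions endpoint
/// (`POST {api_url}/completions`) and parses the response.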
pub async fn complete_text(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: CompletionRequest,
) -> Result<CompletionResponse> {
    let uri = format!("{api_url}/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

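/// Wraps a non-streaming `Response` in the shape of a single streaming event,
/// so `stream_completion` can return a uniform stream type even when the
/// underlying request was not streamed.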
fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
    ResponseStreamEvent {
        created: response.created as u32,
        model: response.model,
        choices: response
            .choices
            .into_iter()
            .map(|choice| ChoiceDelta {
                index: choice.index,
                delta: ResponseMessageDelta {
                    role: Some(match choice.message {
                        RequestMessage::Assistant { .. } => Role::Assistant,
                        RequestMessage::User { .. } => Role::User,
                        RequestMessage::System { .. } => Role::System,
                        RequestMessage::Tool { .. } => Role::Tool,
                    }),
                    content: match choice.message {
                        RequestMessage::Assistant { content, .. } => content,
                        RequestMessage::User { content } => Some(content),
                        RequestMessage::System { content } => Some(content),
                        RequestMessage::Tool { content, .. } => Some(content),
                    },
                    tool_calls: None,
                },
                finish_reason: choice.finish_reason,
            })
            .collect(),
        usage: Some(response.usage),
    }
}

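/// Streams a chat completion over server-sent events, yielding one
/// `ResponseStreamEvent` per `data:` line and skipping the `[DONE]` sentinel.
///
/// A consumption sketch, assuming a `client` of type `&dyn HttpClient` is in
/// scope and `futures::StreamExt` is imported:
///
/// ```ignore
/// let mut events =
///     stream_completion(client, OPEN_AI_API_URL, &api_key, request).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(text) = choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```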
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
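    // o1 models are requested without streaming; the blocking response is
    // adapted into a single stream event so the return type stays uniform.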
    if request.model.starts_with("o1") {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}

#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

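/// Builds an embeddings request for a batch of input texts and returns a
/// future that sends it and parses the response.
///
/// A usage sketch, assuming a `client` of type `&dyn HttpClient` is in scope:
///
/// ```ignore
/// let response = embed(
///     client,
///     OPEN_AI_API_URL,
///     &api_key,
///     OpenAiEmbeddingModel::TextEmbedding3Small,
///     ["first chunk of text", "second chunk of text"],
/// )
/// .await?;
/// let vectors: Vec<Vec<f32>> = response.data.into_iter().map(|d| d.embedding).collect();
/// ```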
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

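/// Narrows a response event stream to the argument fragments of the first tool
/// call whose function name matches `tool_name`, returning an error if the
/// stream ends without such a call. Callers are expected to concatenate the
/// yielded fragments into the complete JSON arguments string.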
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}

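/// Reduces a response event stream to just the text deltas, dropping events
/// that carry no content.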
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}