open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Context, Result};
  4use futures::{
  5    io::BufReader,
  6    stream::{self, BoxStream},
  7    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
  8};
  9use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12use std::{
 13    convert::TryFrom,
 14    future::{self, Future},
 15    pin::Pin,
 16    time::Duration,
 17};
 18use strum::EnumIter;
 19
 20pub use supported_countries::*;
 21
 22pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 23
 24fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 25    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 26}
 27
/// The author role of a chat message, matching the OpenAI chat API.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 36
 37impl TryFrom<String> for Role {
 38    type Error = anyhow::Error;
 39
 40    fn try_from(value: String) -> Result<Self> {
 41        match value.as_str() {
 42            "user" => Ok(Self::User),
 43            "assistant" => Ok(Self::Assistant),
 44            "system" => Ok(Self::System),
 45            "tool" => Ok(Self::Tool),
 46            _ => Err(anyhow!("invalid role '{value}'")),
 47        }
 48    }
 49}
 50
 51impl From<Role> for String {
 52    fn from(val: Role) -> Self {
 53        match val {
 54            Role::User => "user".to_owned(),
 55            Role::Assistant => "assistant".to_owned(),
 56            Role::System => "system".to_owned(),
 57            Role::Tool => "tool".to_owned(),
 58        }
 59    }
 60}
 61
/// The OpenAI chat models this client knows about, plus a `Custom` escape
/// hatch for models configured by name at runtime.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    // Default model when none is configured.
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    /// A user-configured model outside the built-in list.
    #[serde(rename = "custom")]
    Custom {
        // The model id sent to the API.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Context window size in tokens (see `max_token_count`).
        max_tokens: usize,
        // Optional cap on output tokens (see `max_output_tokens`).
        max_output_tokens: Option<u32>,
        // NOTE(review): not read by any accessor in this file — presumably
        // consumed by callers for models that use `max_completion_tokens`.
        max_completion_tokens: Option<u32>,
    },
}
 91
 92impl Model {
 93    pub fn from_id(id: &str) -> Result<Self> {
 94        match id {
 95            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 96            "gpt-4" => Ok(Self::Four),
 97            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 98            "gpt-4o" => Ok(Self::FourOmni),
 99            "gpt-4o-mini" => Ok(Self::FourOmniMini),
100            "o1-preview" => Ok(Self::O1Preview),
101            "o1-mini" => Ok(Self::O1Mini),
102            _ => Err(anyhow!("invalid model id")),
103        }
104    }
105
106    pub fn id(&self) -> &str {
107        match self {
108            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
109            Self::Four => "gpt-4",
110            Self::FourTurbo => "gpt-4-turbo",
111            Self::FourOmni => "gpt-4o",
112            Self::FourOmniMini => "gpt-4o-mini",
113            Self::O1Preview => "o1-preview",
114            Self::O1Mini => "o1-mini",
115            Self::Custom { name, .. } => name,
116        }
117    }
118
119    pub fn display_name(&self) -> &str {
120        match self {
121            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
122            Self::Four => "gpt-4",
123            Self::FourTurbo => "gpt-4-turbo",
124            Self::FourOmni => "gpt-4o",
125            Self::FourOmniMini => "gpt-4o-mini",
126            Self::O1Preview => "o1-preview",
127            Self::O1Mini => "o1-mini",
128            Self::Custom {
129                name, display_name, ..
130            } => display_name.as_ref().unwrap_or(name),
131        }
132    }
133
134    pub fn max_token_count(&self) -> usize {
135        match self {
136            Self::ThreePointFiveTurbo => 16385,
137            Self::Four => 8192,
138            Self::FourTurbo => 128000,
139            Self::FourOmni => 128000,
140            Self::FourOmniMini => 128000,
141            Self::O1Preview => 128000,
142            Self::O1Mini => 128000,
143            Self::Custom { max_tokens, .. } => *max_tokens,
144        }
145    }
146
147    pub fn max_output_tokens(&self) -> Option<u32> {
148        match self {
149            Self::Custom {
150                max_output_tokens, ..
151            } => *max_output_tokens,
152            _ => None,
153        }
154    }
155}
156
/// The request body for the `/chat/completions` endpoint. Optional fields
/// are omitted from the JSON when unset/empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    // `complete` forces this to false; `stream_completion` sends it as given.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
172
/// Controls whether/how the model may call tools.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants `Auto`,
/// `Required`, and `None` serialize to JSON `null` rather than the string
/// values ("auto"/"required"/"none") the OpenAI API documents — confirm this
/// is intended before relying on those variants.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    Other(ToolDefinition),
}
181
/// A tool made available to the model; currently only function tools,
/// serialized as `{"type": "function", "function": {…}}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
188
/// The declaration of a callable function tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // JSON-schema object describing the function's parameters, if any.
    pub parameters: Option<Value>,
}
195
/// A single chat message, tagged by `role` on the wire
/// (e.g. `{"role": "user", "content": …}`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        // Tool invocations requested by the assistant; omitted when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    // The result of a tool invocation, correlated via `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
215
/// A complete tool call carried by an assistant message. The payload is
/// flattened, so `id` and the `type`/`function` fields share one JSON object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
222
/// The payload of a tool call, tagged by `type`; currently only functions.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
228
/// A concrete function invocation: its name plus the arguments as a raw
/// JSON-encoded string (as the API delivers them).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
234
/// An incremental message update within a streamed completion; every field
/// may be absent in any given chunk.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    // Skipped when `None` or an empty list (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
242
/// A fragment of a tool call within a streamed delta. Only `index` is always
/// present; it correlates fragments belonging to the same call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`.
    pub function: Option<FunctionChunk>,
}
253
/// A fragment of a streamed function call: the name usually arrives in the
/// first chunk, the arguments accumulate across chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
259
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
266
/// One choice's incremental update within a streamed completion event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    // Set on the final chunk for this choice (e.g. "stop", "tool_calls").
    pub finish_reason: Option<String>,
}
273
/// The decoded payload of one SSE frame: either a normal event or an error
/// envelope.
///
/// NOTE(review): `Err` matches `{"error": <string>}`; the OpenAI error
/// envelope is typically an object — presumably this shape comes from an
/// intermediary. Confirm against the actual upstream payloads.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
280
/// One streamed completion event: per-choice deltas plus, on the final
/// event, usage statistics.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
288
/// A complete (non-streamed) chat completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    // Unix timestamp of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
298
/// One completed choice within a non-streamed response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
305
306pub async fn complete(
307    client: &dyn HttpClient,
308    api_url: &str,
309    api_key: &str,
310    request: Request,
311    low_speed_timeout: Option<Duration>,
312) -> Result<Response> {
313    let uri = format!("{api_url}/chat/completions");
314    let mut request_builder = HttpRequest::builder()
315        .method(Method::POST)
316        .uri(uri)
317        .header("Content-Type", "application/json")
318        .header("Authorization", format!("Bearer {}", api_key));
319    if let Some(low_speed_timeout) = low_speed_timeout {
320        request_builder = request_builder.read_timeout(low_speed_timeout);
321    };
322
323    let mut request_body = request;
324    request_body.stream = false;
325
326    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
327    let mut response = client.send(request).await?;
328
329    if response.status().is_success() {
330        let mut body = String::new();
331        response.body_mut().read_to_string(&mut body).await?;
332        let response: Response = serde_json::from_str(&body)?;
333        Ok(response)
334    } else {
335        let mut body = String::new();
336        response.body_mut().read_to_string(&mut body).await?;
337
338        #[derive(Deserialize)]
339        struct OpenAiResponse {
340            error: OpenAiError,
341        }
342
343        #[derive(Deserialize)]
344        struct OpenAiError {
345            message: String,
346        }
347
348        match serde_json::from_str::<OpenAiResponse>(&body) {
349            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
350                "Failed to connect to OpenAI API: {}",
351                response.error.message,
352            )),
353
354            _ => Err(anyhow!(
355                "Failed to connect to OpenAI API: {} {}",
356                response.status(),
357                body,
358            )),
359        }
360    }
361}
362
363fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
364    ResponseStreamEvent {
365        created: response.created as u32,
366        model: response.model,
367        choices: response
368            .choices
369            .into_iter()
370            .map(|choice| ChoiceDelta {
371                index: choice.index,
372                delta: ResponseMessageDelta {
373                    role: Some(match choice.message {
374                        RequestMessage::Assistant { .. } => Role::Assistant,
375                        RequestMessage::User { .. } => Role::User,
376                        RequestMessage::System { .. } => Role::System,
377                        RequestMessage::Tool { .. } => Role::Tool,
378                    }),
379                    content: match choice.message {
380                        RequestMessage::Assistant { content, .. } => content,
381                        RequestMessage::User { content } => Some(content),
382                        RequestMessage::System { content } => Some(content),
383                        RequestMessage::Tool { content, .. } => Some(content),
384                    },
385                    tool_calls: None,
386                },
387                finish_reason: choice.finish_reason,
388            })
389            .collect(),
390        usage: Some(response.usage),
391    }
392}
393
/// Performs a chat completion request and returns a stream of response
/// events parsed from the server-sent-event body.
///
/// The o1 models are special-cased through the non-streaming [`complete`]
/// call, with the single response adapted into a one-event stream
/// (NOTE(review): presumably because those models did not support streaming
/// when this was written — confirm against current API docs).
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.read_timeout(low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE payload lines are prefixed with "data: "; other
                        // lines (blank separators, comments) are skipped.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // Skip the end-of-stream sentinel; the stream
                            // itself terminates when the body ends.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Error envelope shape returned by the API: {"error": {"message": …}}.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
472
/// The embedding models supported by the `/embeddings` endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
480
/// Request body for the `/embeddings` endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
486
/// Response body from the `/embeddings` endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
491
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
496
497pub fn embed<'a>(
498    client: &dyn HttpClient,
499    api_url: &str,
500    api_key: &str,
501    model: OpenAiEmbeddingModel,
502    texts: impl IntoIterator<Item = &'a str>,
503) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
504    let uri = format!("{api_url}/embeddings");
505
506    let request = OpenAiEmbeddingRequest {
507        model,
508        input: texts.into_iter().collect(),
509    };
510    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
511    let request = HttpRequest::builder()
512        .method(Method::POST)
513        .uri(uri)
514        .header("Content-Type", "application/json")
515        .header("Authorization", format!("Bearer {}", api_key))
516        .body(body)
517        .map(|request| client.send(request));
518
519    async move {
520        let mut response = request?.await?;
521        let mut body = String::new();
522        response.body_mut().read_to_string(&mut body).await?;
523
524        if response.status().is_success() {
525            let response: OpenAiEmbeddingResponse =
526                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
527            Ok(response)
528        } else {
529            Err(anyhow!(
530                "error during embedding, status: {:?}, body: {:?}",
531                response.status(),
532                body
533            ))
534        }
535    }
536}
537
/// Extracts the streamed argument fragments for the tool named `tool_name`
/// from a chat completion event stream.
///
/// Drains events until a tool-call chunk naming the function appears
/// (recording its `index`, since later chunks carry only the index), then
/// returns a stream of that call's argument fragments. Returns `Err` if the
/// input stream ends without the tool being used.
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    // Phase 1: scan for the first chunk whose function name matches.
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            // The chunk that names the function may already carry the start
            // of the arguments; stash it so it isn't lost.
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    // Phase 2: forward argument fragments for the matched call index,
    // prepending the stashed first fragment to the next one seen.
    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        // Consume the stashed fragment exactly once.
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}
588
589pub fn extract_text_from_events(
590    response: impl Stream<Item = Result<ResponseStreamEvent>>,
591) -> impl Stream<Item = Result<String>> {
592    response.filter_map(|response| async move {
593        match response {
594            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
595            Err(error) => Some(Err(error)),
596        }
597    })
598}