// open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Context, Result};
  4use futures::{
  5    io::BufReader,
  6    stream::{self, BoxStream},
  7    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
  8};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use isahc::config::Configurable;
 11use serde::{Deserialize, Serialize};
 12use serde_json::Value;
 13use std::{
 14    convert::TryFrom,
 15    future::{self, Future},
 16    pin::Pin,
 17    time::Duration,
 18};
 19use strum::EnumIter;
 20
 21pub use supported_countries::*;
 22
 23pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 24
 25fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 26    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 27}
 28
/// The author of a chat message, serialized in lowercase
/// ("user", "assistant", "system", "tool") as the OpenAI API expects.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 37
 38impl TryFrom<String> for Role {
 39    type Error = anyhow::Error;
 40
 41    fn try_from(value: String) -> Result<Self> {
 42        match value.as_str() {
 43            "user" => Ok(Self::User),
 44            "assistant" => Ok(Self::Assistant),
 45            "system" => Ok(Self::System),
 46            "tool" => Ok(Self::Tool),
 47            _ => Err(anyhow!("invalid role '{value}'")),
 48        }
 49    }
 50}
 51
 52impl From<Role> for String {
 53    fn from(val: Role) -> Self {
 54        match val {
 55            Role::User => "user".to_owned(),
 56            Role::Assistant => "assistant".to_owned(),
 57            Role::System => "system".to_owned(),
 58            Role::Tool => "tool".to_owned(),
 59        }
 60    }
 61}
 62
/// A selectable OpenAI chat model, with a `Custom` variant for models
/// not known at compile time. The serde rename gives each variant its
/// wire id (e.g. "gpt-4o").
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,

    /// A user-configured model that is not in the built-in list.
    #[serde(rename = "custom")]
    Custom {
        /// The id sent to the API as the `model` field.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Maximum context window size in tokens.
        max_tokens: usize,
        /// Optional cap on output tokens.
        max_output_tokens: Option<u32>,
        /// Optional completion-token cap — NOTE(review): not read by any
        /// accessor in this file; confirm intended consumer.
        max_completion_tokens: Option<u32>,
    },
}
 92
 93impl Model {
 94    pub fn from_id(id: &str) -> Result<Self> {
 95        match id {
 96            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 97            "gpt-4" => Ok(Self::Four),
 98            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
 99            "gpt-4o" => Ok(Self::FourOmni),
100            "gpt-4o-mini" => Ok(Self::FourOmniMini),
101            "o1-preview" => Ok(Self::O1Preview),
102            "o1-mini" => Ok(Self::O1Mini),
103            _ => Err(anyhow!("invalid model id")),
104        }
105    }
106
107    pub fn id(&self) -> &str {
108        match self {
109            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
110            Self::Four => "gpt-4",
111            Self::FourTurbo => "gpt-4-turbo",
112            Self::FourOmni => "gpt-4o",
113            Self::FourOmniMini => "gpt-4o-mini",
114            Self::O1Preview => "o1-preview",
115            Self::O1Mini => "o1-mini",
116            Self::Custom { name, .. } => name,
117        }
118    }
119
120    pub fn display_name(&self) -> &str {
121        match self {
122            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
123            Self::Four => "gpt-4",
124            Self::FourTurbo => "gpt-4-turbo",
125            Self::FourOmni => "gpt-4o",
126            Self::FourOmniMini => "gpt-4o-mini",
127            Self::O1Preview => "o1-preview",
128            Self::O1Mini => "o1-mini",
129            Self::Custom {
130                name, display_name, ..
131            } => display_name.as_ref().unwrap_or(name),
132        }
133    }
134
135    pub fn max_token_count(&self) -> usize {
136        match self {
137            Self::ThreePointFiveTurbo => 16385,
138            Self::Four => 8192,
139            Self::FourTurbo => 128000,
140            Self::FourOmni => 128000,
141            Self::FourOmniMini => 128000,
142            Self::O1Preview => 128000,
143            Self::O1Mini => 128000,
144            Self::Custom { max_tokens, .. } => *max_tokens,
145        }
146    }
147
148    pub fn max_output_tokens(&self) -> Option<u32> {
149        match self {
150            Self::Custom {
151                max_output_tokens, ..
152            } => *max_output_tokens,
153            _ => None,
154        }
155    }
156}
157
/// Request body for the `/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model wire id, e.g. "gpt-4o".
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether to request a server-sent-event stream.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
173
174#[derive(Debug, Serialize, Deserialize)]
175#[serde(untagged)]
176pub enum ToolChoice {
177    Auto,
178    Required,
179    None,
180    Other(ToolDefinition),
181}
182
/// A tool the model may call; currently only function tools exist,
/// tagged on the wire as `{"type": "function", ...}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
189
/// Declaration of a callable function: its name plus an optional
/// description and JSON-schema `parameters` object.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema for the arguments — presumably an object schema per
    /// the OpenAI function-calling convention; TODO confirm with callers.
    pub parameters: Option<Value>,
}
196
/// A chat message, tagged on the wire by its lowercase `role` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Assistant turn; may carry tool calls and may have no text content.
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// End-user message.
    User {
        content: String,
    },
    /// System / instruction message.
    System {
        content: String,
    },
    /// Result of a tool invocation, linked back to its originating call.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
216
/// A complete (non-streamed) tool call made by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened into this object on the wire (adds `type` + payload).
    #[serde(flatten)]
    pub content: ToolCallContent,
}
223
/// Payload of a tool call, tagged by `type`; only functions exist today.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
229
/// A function invocation: the function name and its arguments as a raw
/// JSON-encoded string (not parsed here).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
235
/// Incremental message delta inside a streamed choice; every field is
/// optional because each chunk carries only what changed.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
243
/// A streamed fragment of a tool call. Fragments belonging to the same
/// call share an `index`; `id` and `function` arrive incrementally.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`.
    pub function: Option<FunctionChunk>,
}
254
/// A streamed fragment of a function call: the name typically arrives in
/// the first chunk, argument text accumulates across later chunks.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
260
/// Token accounting reported by the API for a completed request.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
267
/// One choice within a streamed event: its delta plus an optional
/// finish reason on the final chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
274
/// A single parsed SSE payload: either a completion event, or an error
/// object reported inline by the API. Untagged, so serde tries the
/// event shape first and falls back to the `{"error": ...}` shape.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}
281
/// One chunk of a streaming chat completion; `usage` is typically only
/// present on the final chunk (when present at all).
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}
289
/// A complete (non-streaming) chat completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
299
/// One choice of a non-streaming response. Reuses [`RequestMessage`]
/// for the message body since request and response messages share the
/// same wire shape.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
306
307pub async fn complete(
308    client: &dyn HttpClient,
309    api_url: &str,
310    api_key: &str,
311    request: Request,
312    low_speed_timeout: Option<Duration>,
313) -> Result<Response> {
314    let uri = format!("{api_url}/chat/completions");
315    let mut request_builder = HttpRequest::builder()
316        .method(Method::POST)
317        .uri(uri)
318        .header("Content-Type", "application/json")
319        .header("Authorization", format!("Bearer {}", api_key));
320    if let Some(low_speed_timeout) = low_speed_timeout {
321        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
322    };
323
324    let mut request_body = request;
325    request_body.stream = false;
326
327    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
328    let mut response = client.send(request).await?;
329
330    if response.status().is_success() {
331        let mut body = String::new();
332        response.body_mut().read_to_string(&mut body).await?;
333        let response: Response = serde_json::from_str(&body)?;
334        Ok(response)
335    } else {
336        let mut body = String::new();
337        response.body_mut().read_to_string(&mut body).await?;
338
339        #[derive(Deserialize)]
340        struct OpenAiResponse {
341            error: OpenAiError,
342        }
343
344        #[derive(Deserialize)]
345        struct OpenAiError {
346            message: String,
347        }
348
349        match serde_json::from_str::<OpenAiResponse>(&body) {
350            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
351                "Failed to connect to OpenAI API: {}",
352                response.error.message,
353            )),
354
355            _ => Err(anyhow!(
356                "Failed to connect to OpenAI API: {} {}",
357                response.status(),
358                body,
359            )),
360        }
361    }
362}
363
364fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
365    ResponseStreamEvent {
366        created: response.created as u32,
367        model: response.model,
368        choices: response
369            .choices
370            .into_iter()
371            .map(|choice| ChoiceDelta {
372                index: choice.index,
373                delta: ResponseMessageDelta {
374                    role: Some(match choice.message {
375                        RequestMessage::Assistant { .. } => Role::Assistant,
376                        RequestMessage::User { .. } => Role::User,
377                        RequestMessage::System { .. } => Role::System,
378                        RequestMessage::Tool { .. } => Role::Tool,
379                    }),
380                    content: match choice.message {
381                        RequestMessage::Assistant { content, .. } => content,
382                        RequestMessage::User { content } => Some(content),
383                        RequestMessage::System { content } => Some(content),
384                        RequestMessage::Tool { content, .. } => Some(content),
385                    },
386                    tool_calls: None,
387                },
388                finish_reason: choice.finish_reason,
389            })
390            .collect(),
391        usage: Some(response.usage),
392    }
393}
394
/// Performs a streaming chat completion request, yielding one
/// [`ResponseStreamEvent`] per `data:` line of the server-sent-event
/// response.
///
/// The o1-preview / o1-mini models do not support streaming, so for
/// those the request goes through [`complete`] and the single response
/// is adapted into a one-event stream.
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
    low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    if request.model == "o1-preview" || request.model == "o1-mini" {
        let response = complete(client, api_url, api_key, request, low_speed_timeout).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    if let Some(low_speed_timeout) = low_speed_timeout {
        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
    };

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE frames look like "data: <json>"; lines
                        // without that prefix are dropped.
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            // Sentinel marking the end of the stream.
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        // Local mirror of the API's error envelope, used only to extract
        // a human-readable message.
        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
473
/// The embedding models supported by this client, serialized as their
/// wire ids.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
481
/// Request body for the `/embeddings` endpoint; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}
487
/// Response from the `/embeddings` endpoint: one entry per input text.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}
492
/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
497
/// Requests embeddings for `texts` from `{api_url}/embeddings`.
///
/// The HTTP request is built eagerly — before the returned future is
/// polled — so the future can be `'static` and not borrow `client`,
/// `api_url`, `api_key`, or `texts`.
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    // Serialization of this plain struct cannot fail in practice —
    // NOTE(review): the unwrap assumes that; confirm if input can embed
    // non-serializable data.
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}
538
/// Filters a completion event stream down to the argument fragments of
/// the first tool call whose function name equals `tool_name`.
///
/// Works in two phases: first it drains `events` until a chunk naming
/// the tool appears, remembering that call's `index` and any argument
/// text already present in that first chunk; then it yields the
/// argument fragments of every subsequent chunk sharing that index,
/// prepending the remembered first fragment to the next one seen.
///
/// Returns `Err` if the stream ends (or errors out) before the tool is
/// ever used.
pub async fn extract_tool_args_from_events(
    tool_name: String,
    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
    // Phase 1: find the chunk that introduces the named tool call.
    let mut tool_use_index = None;
    let mut first_chunk = None;
    while let Some(event) = events.next().await {
        let call = event?.choices.into_iter().find_map(|choice| {
            choice.delta.tool_calls?.into_iter().find_map(|call| {
                // Only the introducing chunk carries the function name.
                if call.function.as_ref()?.name.as_deref()? == tool_name {
                    Some(call)
                } else {
                    None
                }
            })
        });
        if let Some(call) = call {
            tool_use_index = Some(call.index);
            // The introducing chunk may already carry argument text.
            first_chunk = call.function.and_then(|func| func.arguments);
            break;
        }
    }

    let Some(tool_use_index) = tool_use_index else {
        return Err(anyhow!("tool not used"));
    };

    // Phase 2: forward argument fragments for chunks with the matching
    // index; `first_chunk` is consumed (via `take`) on the first match.
    Ok(events.filter_map(move |event| {
        let result = match event {
            Err(error) => Some(Err(error)),
            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
                choice.delta.tool_calls?.into_iter().find_map(|call| {
                    if call.index == tool_use_index {
                        let func = call.function?;
                        let mut arguments = func.arguments?;
                        if let Some(mut first_chunk) = first_chunk.take() {
                            first_chunk.push_str(&arguments);
                            arguments = first_chunk
                        }
                        Some(Ok(arguments))
                    } else {
                        None
                    }
                })
            }),
        };

        async move { result }
    }))
}
589
590pub fn extract_text_from_events(
591    response: impl Stream<Item = Result<ResponseStreamEvent>>,
592) -> impl Stream<Item = Result<String>> {
593    response.filter_map(|response| async move {
594        match response {
595            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
596            Err(error) => Some(Err(error)),
597        }
598    })
599}