open_ai.rs

mod supported_countries;

use anyhow::{Context as _, Result, anyhow};
use futures::{
    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
    io::BufReader,
    stream::{self, BoxStream},
};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{
    convert::TryFrom,
    future::{self, Future},
};
use strum::EnumIter;

pub use supported_countries::*;

pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";

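/// Returns true if `opt` is `None` or wraps an empty slice. Used as a
/// `skip_serializing_if` predicate below.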
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
}

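/// The author of a chat message, mirroring the `role` field of the OpenAI
/// chat completions API.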
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self> {
        match value.as_str() {
            "user" => Ok(Self::User),
            "assistant" => Ok(Self::Assistant),
            "system" => Ok(Self::System),
            "tool" => Ok(Self::Tool),
            _ => Err(anyhow!("invalid role '{value}'")),
        }
    }
}

impl From<Role> for String {
    fn from(val: Role) -> Self {
        match val {
            Role::User => "user".to_owned(),
            Role::Assistant => "assistant".to_owned(),
            Role::System => "system".to_owned(),
            Role::Tool => "tool".to_owned(),
        }
    }
}

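/// The OpenAI models this crate knows about, plus a `Custom` variant for
/// model ids that are not listed here.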
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,

    #[serde(rename = "custom")]
    Custom {
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        max_tokens: usize,
        max_output_tokens: Option<u32>,
        max_completion_tokens: Option<u32>,
    },
}

impl Model {
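    /// Resolves a built-in model from its wire identifier.
    ///
    /// A minimal round-trip sketch; marked `ignore` because running it as a
    /// doctest depends on how this crate is built:
    ///
    /// ```ignore
    /// let model = Model::from_id("gpt-4o").unwrap();
    /// assert_eq!(model.id(), "gpt-4o");
    /// ```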
    pub fn from_id(id: &str) -> Result<Self> {
        match id {
            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
            "gpt-4" => Ok(Self::Four),
            "gpt-4-turbo" | "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
            "gpt-4o" => Ok(Self::FourOmni),
            "gpt-4o-mini" => Ok(Self::FourOmniMini),
            "o1" => Ok(Self::O1),
            "o1-preview" => Ok(Self::O1Preview),
            "o1-mini" => Ok(Self::O1Mini),
            "o3-mini" => Ok(Self::O3Mini),
            _ => Err(anyhow!("invalid model id '{id}'")),
        }
    }

    pub fn id(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1 => "o1",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::O3Mini => "o3-mini",
            Self::Custom { name, .. } => name,
        }
    }

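    /// The name shown in the UI. For custom models this prefers the
    /// configured `display_name` and falls back to the raw model `name`.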
    pub fn display_name(&self) -> &str {
        match self {
            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
            Self::Four => "gpt-4",
            Self::FourTurbo => "gpt-4-turbo",
            Self::FourOmni => "gpt-4o",
            Self::FourOmniMini => "gpt-4o-mini",
            Self::O1 => "o1",
            Self::O1Preview => "o1-preview",
            Self::O1Mini => "o1-mini",
            Self::O3Mini => "o3-mini",
            Self::Custom {
                name, display_name, ..
            } => display_name.as_ref().unwrap_or(name),
        }
    }

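    /// The total context window for the model, in tokens.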
    pub fn max_token_count(&self) -> usize {
        match self {
            Self::ThreePointFiveTurbo => 16_385,
            Self::Four => 8_192,
            Self::FourTurbo => 128_000,
            Self::FourOmni => 128_000,
            Self::FourOmniMini => 128_000,
            Self::O1 => 200_000,
            Self::O1Preview => 128_000,
            Self::O1Mini => 128_000,
            Self::O3Mini => 200_000,
            Self::Custom { max_tokens, .. } => *max_tokens,
        }
    }

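    /// An explicit cap on output tokens. Only custom models carry one; the
    /// built-in models return `None` and rely on the request's `max_tokens`.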
    pub fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Self::Custom {
                max_output_tokens, ..
            } => *max_output_tokens,
            _ => None,
        }
    }
}

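/// Request body for the chat completions endpoint (`POST {api_url}/chat/completions`).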
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}

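/// Request body for the legacy text completions endpoint (`POST {api_url}/completions`).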
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    Content { content: String },
}

// With a container-level `untagged` attribute, the unit variants here
// serialized to JSON `null` rather than the "auto"/"required"/"none" strings
// the API expects. The unit variants are therefore left externally tagged
// (lowercased) and only `Other` is untagged; variant-level `untagged`
// requires serde 1.0.181 or newer.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // The API also streams an optional `type` field indicating whether this
    // chunk carries a function call, but the `function` payload sometimes
    // arrives before `type` does, so we don't depend on it here.
    pub function: Option<FunctionChunk>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: String },
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    pub text: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}

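/// Sends a non-streaming chat completion request.
///
/// The `stream` field on `request` is forced to `false` before sending. A
/// minimal usage sketch follows; the `client` and `api_key` bindings are
/// assumptions that depend on the embedding application, so the doctest is
/// marked `ignore`:
///
/// ```ignore
/// let request = Request {
///     model: Model::FourOmni.id().into(),
///     messages: vec![RequestMessage::User { content: "Hello!".into() }],
///     stream: false,
///     max_tokens: None,
///     stop: Vec::new(),
///     temperature: 1.0,
///     tool_choice: None,
///     tools: Vec::new(),
/// };
/// let response = complete(client.as_ref(), OPEN_AI_API_URL, &api_key, request).await?;
/// println!("{:?}", response.choices.first());
/// ```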
pub async fn complete(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<Response> {
    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let mut request_body = request;
    request_body.stream = false;

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response: Response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Error from OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "OpenAI API request failed with status {}: {}",
                response.status(),
                body,
            )),
        }
    }
}

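/// Sends a non-streaming request to the legacy text completions endpoint.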
pub async fn complete_text(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: CompletionRequest,
) -> Result<CompletionResponse> {
    let uri = format!("{api_url}/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        let response = serde_json::from_str(&body)?;
        Ok(response)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Error from OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "OpenAI API request failed with status {}: {}",
                response.status(),
                body,
            )),
        }
    }
}

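/// Converts a non-streaming `Response` into the streaming event shape so that
/// callers of `stream_completion` can handle streaming and non-streaming
/// models uniformly.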
fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
    ResponseStreamEvent {
        created: response.created as u32,
        model: response.model,
        choices: response
            .choices
            .into_iter()
            .map(|choice| ChoiceDelta {
                index: choice.index,
                delta: ResponseMessageDelta {
                    role: Some(match choice.message {
                        RequestMessage::Assistant { .. } => Role::Assistant,
                        RequestMessage::User { .. } => Role::User,
                        RequestMessage::System { .. } => Role::System,
                        RequestMessage::Tool { .. } => Role::Tool,
                    }),
                    content: match choice.message {
                        RequestMessage::Assistant { content, .. } => content,
                        RequestMessage::User { content } => Some(content),
                        RequestMessage::System { content } => Some(content),
                        RequestMessage::Tool { content, .. } => Some(content),
                    },
                    tool_calls: None,
                },
                finish_reason: choice.finish_reason,
            })
            .collect(),
        usage: Some(response.usage),
    }
}

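/// Streams chat completion events from `{api_url}/chat/completions` as
/// server-sent `data:` lines, ending when the `[DONE]` sentinel arrives.
///
/// Requests for models whose ids start with `o1` are not streamed; they go
/// through `complete` and are adapted into a single stream event. A hedged
/// usage sketch (the `client` and `api_key` bindings are assumptions, hence
/// `ignore`):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut events = stream_completion(client.as_ref(), OPEN_AI_API_URL, &api_key, request).await?;
/// while let Some(event) = events.next().await {
///     for choice in event?.choices {
///         if let Some(text) = choice.delta.content {
///             print!("{text}");
///         }
///     }
/// }
/// ```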
pub async fn stream_completion(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
    // o1-family models are handled without streaming: issue a blocking
    // request and wrap the result in a single-event stream.
    if request.model.starts_with("o1") {
        let response = complete(client, api_url, api_key, request).await;
        let response_stream_event = response.map(adapt_response_to_stream);
        return Ok(stream::once(future::ready(response_stream_event)).boxed());
    }

    let uri = format!("{api_url}/chat/completions");
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key));

    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
    let mut response = client.send(request).await?;
    if response.status().is_success() {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line == "[DONE]" {
                            None
                        } else {
                            match serde_json::from_str(line) {
                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
                                Ok(ResponseStreamResult::Err { error }) => {
                                    Some(Err(anyhow!(error)))
                                }
                                Err(error) => Some(Err(anyhow!(error))),
                            }
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAiResponse {
            error: OpenAiError,
        }

        #[derive(Deserialize)]
        struct OpenAiError {
            message: String,
        }

        match serde_json::from_str::<OpenAiResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Error from OpenAI API: {}",
                response.error.message,
            )),

            _ => Err(anyhow!(
                "OpenAI API request failed with status {}: {}",
                response.status(),
                body,
            )),
        }
    }
}

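/// Embedding models accepted by the `{api_url}/embeddings` endpoint.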
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    pub data: Vec<OpenAiEmbedding>,
}

#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}

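/// Requests embeddings for `texts` from `{api_url}/embeddings`.
///
/// The request is built eagerly so the returned future is `'static` and does
/// not borrow its arguments. A hedged usage sketch (the `client` and
/// `api_key` bindings are assumptions, hence `ignore`):
///
/// ```ignore
/// let response = embed(
///     client.as_ref(),
///     OPEN_AI_API_URL,
///     &api_key,
///     OpenAiEmbeddingModel::TextEmbedding3Small,
///     ["first document", "second document"],
/// )
/// .await?;
/// assert_eq!(response.data.len(), 2);
/// ```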
pub fn embed<'a>(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    model: OpenAiEmbeddingModel,
    texts: impl IntoIterator<Item = &'a str>,
) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
    let uri = format!("{api_url}/embeddings");

    let request = OpenAiEmbeddingRequest {
        model,
        input: texts.into_iter().collect(),
    };
    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(body)
        .map(|request| client.send(request));

    async move {
        let mut response = request?.await?;
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        if response.status().is_success() {
            let response: OpenAiEmbeddingResponse =
                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
            Ok(response)
        } else {
            Err(anyhow!(
                "error during embedding, status: {:?}, body: {:?}",
                response.status(),
                body
            ))
        }
    }
}

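/// Reduces a stream of completion events to the assistant's text deltas,
/// dropping events that carry no content.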
pub fn extract_text_from_events(
    response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
    response.filter_map(|response| async move {
        match response {
            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
            Err(error) => Some(Err(error)),
        }
    })
}