open_ai.rs

  1mod supported_countries;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use futures::{
  5    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
  6    io::BufReader,
  7    stream::{self, BoxStream},
  8};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use serde_json::Value;
 12use std::{
 13    convert::TryFrom,
 14    future::{self, Future},
 15    pin::Pin,
 16};
 17use strum::EnumIter;
 18
 19pub use supported_countries::*;
 20
/// Base URL of version 1 of the OpenAI HTTP API.
///
/// The request helpers in this file take an `api_url` argument and append an
/// endpoint path (for example `/chat/completions`) to it; presumably this
/// constant is the value callers pass by default — confirm against callers.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 22
 23fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 24    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 25}
 26
/// The author role attached to a chat message.
///
/// Serde (de)serializes each variant as its lowercase name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Wire value `"user"`.
    User,
    /// Wire value `"assistant"`.
    Assistant,
    /// Wire value `"system"`.
    System,
    /// Wire value `"tool"`.
    Tool,
}
 35
 36impl TryFrom<String> for Role {
 37    type Error = anyhow::Error;
 38
 39    fn try_from(value: String) -> Result<Self> {
 40        match value.as_str() {
 41            "user" => Ok(Self::User),
 42            "assistant" => Ok(Self::Assistant),
 43            "system" => Ok(Self::System),
 44            "tool" => Ok(Self::Tool),
 45            _ => Err(anyhow!("invalid role '{value}'")),
 46        }
 47    }
 48}
 49
 50impl From<Role> for String {
 51    fn from(val: Role) -> Self {
 52        match val {
 53            Role::User => "user".to_owned(),
 54            Role::Assistant => "assistant".to_owned(),
 55            Role::System => "system".to_owned(),
 56            Role::Tool => "tool".to_owned(),
 57        }
 58    }
 59}
 60
/// A chat model selectable through this client.
///
/// The built-in variants (de)serialize as the model id given in their
/// `serde(rename)` attribute; `Custom` (de)serializes under the name
/// `custom` and carries its own id and limits. `FourOmni` is the default.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4", alias = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo", alias = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o", alias = "gpt-4o")]
    #[default]
    FourOmni,
    #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "o1", alias = "o1")]
    O1,
    #[serde(rename = "o1-preview", alias = "o1-preview")]
    O1Preview,
    #[serde(rename = "o1-mini", alias = "o1-mini")]
    O1Mini,
    #[serde(rename = "o3-mini", alias = "o3-mini")]
    O3Mini,

    #[serde(rename = "custom")]
    Custom {
        // Returned by `Model::id` as the model id.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        // Returned by `Model::max_token_count`.
        max_tokens: usize,
        // Returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u32>,
        // NOTE(review): not read anywhere in this file — confirm how callers use it.
        max_completion_tokens: Option<u32>,
    },
}
 94
 95impl Model {
 96    pub fn from_id(id: &str) -> Result<Self> {
 97        match id {
 98            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
 99            "gpt-4" => Ok(Self::Four),
100            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
101            "gpt-4o" => Ok(Self::FourOmni),
102            "gpt-4o-mini" => Ok(Self::FourOmniMini),
103            "o1" => Ok(Self::O1),
104            "o1-preview" => Ok(Self::O1Preview),
105            "o1-mini" => Ok(Self::O1Mini),
106            "o3-mini" => Ok(Self::O3Mini),
107            _ => Err(anyhow!("invalid model id")),
108        }
109    }
110
111    pub fn id(&self) -> &str {
112        match self {
113            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
114            Self::Four => "gpt-4",
115            Self::FourTurbo => "gpt-4-turbo",
116            Self::FourOmni => "gpt-4o",
117            Self::FourOmniMini => "gpt-4o-mini",
118            Self::O1 => "o1",
119            Self::O1Preview => "o1-preview",
120            Self::O1Mini => "o1-mini",
121            Self::O3Mini => "o3-mini",
122            Self::Custom { name, .. } => name,
123        }
124    }
125
126    pub fn display_name(&self) -> &str {
127        match self {
128            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
129            Self::Four => "gpt-4",
130            Self::FourTurbo => "gpt-4-turbo",
131            Self::FourOmni => "gpt-4o",
132            Self::FourOmniMini => "gpt-4o-mini",
133            Self::O1 => "o1",
134            Self::O1Preview => "o1-preview",
135            Self::O1Mini => "o1-mini",
136            Self::O3Mini => "o3-mini",
137            Self::Custom {
138                name, display_name, ..
139            } => display_name.as_ref().unwrap_or(name),
140        }
141    }
142
143    pub fn max_token_count(&self) -> usize {
144        match self {
145            Self::ThreePointFiveTurbo => 16_385,
146            Self::Four => 8_192,
147            Self::FourTurbo => 128_000,
148            Self::FourOmni => 128_000,
149            Self::FourOmniMini => 128_000,
150            Self::O1 => 200_000,
151            Self::O1Preview => 128_000,
152            Self::O1Mini => 128_000,
153            Self::O3Mini => 200_000,
154            Self::Custom { max_tokens, .. } => *max_tokens,
155        }
156    }
157
158    pub fn max_output_tokens(&self) -> Option<u32> {
159        match self {
160            Self::Custom {
161                max_output_tokens, ..
162            } => *max_output_tokens,
163            _ => None,
164        }
165    }
166}
167
/// JSON body sent to the `/chat/completions` endpoint by `complete` and
/// `stream_completion`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id; `stream_completion` checks whether it starts with `"o1"`.
    pub model: String,
    /// The conversation messages, in order.
    pub messages: Vec<RequestMessage>,
    /// Whether a streamed response is requested; `complete` always overwrites this with `false`.
    pub stream: bool,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
183
/// JSON body sent to the `/completions` endpoint by `complete_text`.
#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
    pub model: String,
    pub prompt: String,
    pub max_tokens: u32,
    pub temperature: f32,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prediction: Option<Prediction>,
    /// Omitted from the JSON when `None`.
    // NOTE(review): the meaning of this flag is not visible in this file — confirm with the server side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rewrite_speculation: Option<bool>,
}
195
/// Optional `prediction` payload of a [`CompletionRequest`].
///
/// Internally tagged: serialized as an object with a `"type"` field holding
/// the snake_case variant name (`"content"`) next to the variant's fields.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
    /// Predicted text, carried in the `content` field.
    Content { content: String },
}
201
202#[derive(Debug, Serialize, Deserialize)]
203#[serde(untagged)]
204pub enum ToolChoice {
205    Auto,
206    Required,
207    None,
208    Other(ToolDefinition),
209}
210
/// A tool the model may call, as listed in [`Request::tools`].
///
/// Internally tagged: serialized as an object with a `"type"` field holding
/// the snake_case variant name (`"function"`) next to the variant's fields.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A callable function described by a [`FunctionDefinition`].
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
217
/// Description of a function exposed to the model as a tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name; `extract_tool_args_from_events` matches tool calls by this name.
    pub name: String,
    pub description: Option<String>,
    /// Arbitrary JSON value — presumably a JSON Schema for the arguments; confirm against the API docs.
    pub parameters: Option<Value>,
}
224
/// A single chat message.
///
/// Internally tagged by a `"role"` field holding the lowercase variant name
/// (`"assistant"`, `"user"`, `"system"`, `"tool"`). Also used as the message
/// type of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A message authored by the model.
    Assistant {
        content: Option<String>,
        /// Omitted from the JSON when empty; defaults to empty when absent.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// The output of a tool, tied to a tool call through `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
244
/// A complete tool call attached to an assistant [`RequestMessage`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened: the `type` tag and payload appear alongside `id` in the same JSON object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
251
/// Payload of a [`ToolCall`].
///
/// Internally tagged by a `"type"` field holding the lowercase variant name
/// (`"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    /// A function invocation.
    Function { function: FunctionContent },
}
257
/// Name and arguments of a function invocation inside a [`ToolCallContent`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a string — presumably JSON-encoded; this file never parses it.
    pub arguments: String,
}
263
/// The message fragment carried by one [`ChoiceDelta`]; every field is optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    /// Text fragment; `extract_text_from_events` yields these.
    pub content: Option<String>,
    /// Defaults to `None` when absent; omitted from the JSON when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
271
/// A fragment of a tool call inside a [`ResponseMessageDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Identifies which tool call this fragment belongs to;
    /// `extract_tool_args_from_events` uses it to follow a single call.
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}
282
/// The function part of a [`ToolCallChunk`]; both fields may be absent in any given fragment.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    /// Function name; `extract_tool_args_from_events` uses it to find the call for a given tool.
    pub name: Option<String>,
    /// A fragment of the argument string.
    pub arguments: Option<String>,
}
288
/// Token counts reported in a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
295
/// One choice within a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    /// The message fragment for this choice.
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}
302
/// The decoded payload of one `data: ` line in `stream_completion`.
///
/// Untagged: serde tries the variants in declaration order, so a payload is
/// first decoded as a [`ResponseStreamEvent`] and only then as an error object.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    /// A regular stream event.
    Ok(ResponseStreamEvent),
    /// An object with a string-valued `error` field.
    // NOTE(review): an `error` that is a nested object would not match this variant — confirm the server's error shape.
    Err { error: String },
}
309
/// One event of a chat completion stream, as yielded by `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    // Presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the event carries it.
    pub usage: Option<Usage>,
}
317
/// Successful response body of the `/completions` endpoint, as returned by `complete_text`.
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    // Presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    pub usage: Usage,
}
327
/// One generated alternative inside a [`CompletionResponse`].
#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
    /// The generated text.
    pub text: String,
}
332
/// Successful non-streaming response body of the `/chat/completions` endpoint,
/// as returned by `complete`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    // Presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
342
/// One choice inside a non-streaming [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    /// The full message; decoded with the same type used for request messages.
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
349
350pub async fn complete(
351    client: &dyn HttpClient,
352    api_url: &str,
353    api_key: &str,
354    request: Request,
355) -> Result<Response> {
356    let uri = format!("{api_url}/chat/completions");
357    let request_builder = HttpRequest::builder()
358        .method(Method::POST)
359        .uri(uri)
360        .header("Content-Type", "application/json")
361        .header("Authorization", format!("Bearer {}", api_key));
362
363    let mut request_body = request;
364    request_body.stream = false;
365
366    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
367    let mut response = client.send(request).await?;
368
369    if response.status().is_success() {
370        let mut body = String::new();
371        response.body_mut().read_to_string(&mut body).await?;
372        let response: Response = serde_json::from_str(&body)?;
373        Ok(response)
374    } else {
375        let mut body = String::new();
376        response.body_mut().read_to_string(&mut body).await?;
377
378        #[derive(Deserialize)]
379        struct OpenAiResponse {
380            error: OpenAiError,
381        }
382
383        #[derive(Deserialize)]
384        struct OpenAiError {
385            message: String,
386        }
387
388        match serde_json::from_str::<OpenAiResponse>(&body) {
389            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
390                "Failed to connect to OpenAI API: {}",
391                response.error.message,
392            )),
393
394            _ => Err(anyhow!(
395                "Failed to connect to OpenAI API: {} {}",
396                response.status(),
397                body,
398            )),
399        }
400    }
401}
402
403pub async fn complete_text(
404    client: &dyn HttpClient,
405    api_url: &str,
406    api_key: &str,
407    request: CompletionRequest,
408) -> Result<CompletionResponse> {
409    let uri = format!("{api_url}/completions");
410    let request_builder = HttpRequest::builder()
411        .method(Method::POST)
412        .uri(uri)
413        .header("Content-Type", "application/json")
414        .header("Authorization", format!("Bearer {}", api_key));
415
416    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
417    let mut response = client.send(request).await?;
418
419    if response.status().is_success() {
420        let mut body = String::new();
421        response.body_mut().read_to_string(&mut body).await?;
422        let response = serde_json::from_str(&body)?;
423        Ok(response)
424    } else {
425        let mut body = String::new();
426        response.body_mut().read_to_string(&mut body).await?;
427
428        #[derive(Deserialize)]
429        struct OpenAiResponse {
430            error: OpenAiError,
431        }
432
433        #[derive(Deserialize)]
434        struct OpenAiError {
435            message: String,
436        }
437
438        match serde_json::from_str::<OpenAiResponse>(&body) {
439            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
440                "Failed to connect to OpenAI API: {}",
441                response.error.message,
442            )),
443
444            _ => Err(anyhow!(
445                "Failed to connect to OpenAI API: {} {}",
446                response.status(),
447                body,
448            )),
449        }
450    }
451}
452
453fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
454    ResponseStreamEvent {
455        created: response.created as u32,
456        model: response.model,
457        choices: response
458            .choices
459            .into_iter()
460            .map(|choice| ChoiceDelta {
461                index: choice.index,
462                delta: ResponseMessageDelta {
463                    role: Some(match choice.message {
464                        RequestMessage::Assistant { .. } => Role::Assistant,
465                        RequestMessage::User { .. } => Role::User,
466                        RequestMessage::System { .. } => Role::System,
467                        RequestMessage::Tool { .. } => Role::Tool,
468                    }),
469                    content: match choice.message {
470                        RequestMessage::Assistant { content, .. } => content,
471                        RequestMessage::User { content } => Some(content),
472                        RequestMessage::System { content } => Some(content),
473                        RequestMessage::Tool { content, .. } => Some(content),
474                    },
475                    tool_calls: None,
476                },
477                finish_reason: choice.finish_reason,
478            })
479            .collect(),
480        usage: Some(response.usage),
481    }
482}
483
484pub async fn stream_completion(
485    client: &dyn HttpClient,
486    api_url: &str,
487    api_key: &str,
488    request: Request,
489) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
490    if request.model.starts_with("o1") {
491        let response = complete(client, api_url, api_key, request).await;
492        let response_stream_event = response.map(adapt_response_to_stream);
493        return Ok(stream::once(future::ready(response_stream_event)).boxed());
494    }
495
496    let uri = format!("{api_url}/chat/completions");
497    let request_builder = HttpRequest::builder()
498        .method(Method::POST)
499        .uri(uri)
500        .header("Content-Type", "application/json")
501        .header("Authorization", format!("Bearer {}", api_key));
502
503    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
504    let mut response = client.send(request).await?;
505    if response.status().is_success() {
506        let reader = BufReader::new(response.into_body());
507        Ok(reader
508            .lines()
509            .filter_map(|line| async move {
510                match line {
511                    Ok(line) => {
512                        let line = line.strip_prefix("data: ")?;
513                        if line == "[DONE]" {
514                            None
515                        } else {
516                            match serde_json::from_str(line) {
517                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
518                                Ok(ResponseStreamResult::Err { error }) => {
519                                    Some(Err(anyhow!(error)))
520                                }
521                                Err(error) => Some(Err(anyhow!(error))),
522                            }
523                        }
524                    }
525                    Err(error) => Some(Err(anyhow!(error))),
526                }
527            })
528            .boxed())
529    } else {
530        let mut body = String::new();
531        response.body_mut().read_to_string(&mut body).await?;
532
533        #[derive(Deserialize)]
534        struct OpenAiResponse {
535            error: OpenAiError,
536        }
537
538        #[derive(Deserialize)]
539        struct OpenAiError {
540            message: String,
541        }
542
543        match serde_json::from_str::<OpenAiResponse>(&body) {
544            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
545                "Failed to connect to OpenAI API: {}",
546                response.error.message,
547            )),
548
549            _ => Err(anyhow!(
550                "Failed to connect to OpenAI API: {} {}",
551                response.status(),
552                body,
553            )),
554        }
555    }
556}
557
/// Embedding models selectable in [`embed`]; each variant (de)serializes as
/// the model id given in its `serde(rename)` attribute.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}
565
/// JSON body sent to the `/embeddings` endpoint by [`embed`].
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    /// The texts to embed, borrowed from the caller.
    input: Vec<&'a str>,
}
571
/// Successful response body of the `/embeddings` endpoint, as returned by [`embed`].
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// The returned embeddings — presumably one per input text, in input order; confirm against the API docs.
    pub data: Vec<OpenAiEmbedding>,
}
576
/// A single embedding inside an [`OpenAiEmbeddingResponse`].
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    /// The embedding vector.
    pub embedding: Vec<f32>,
}
581
582pub fn embed<'a>(
583    client: &dyn HttpClient,
584    api_url: &str,
585    api_key: &str,
586    model: OpenAiEmbeddingModel,
587    texts: impl IntoIterator<Item = &'a str>,
588) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
589    let uri = format!("{api_url}/embeddings");
590
591    let request = OpenAiEmbeddingRequest {
592        model,
593        input: texts.into_iter().collect(),
594    };
595    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
596    let request = HttpRequest::builder()
597        .method(Method::POST)
598        .uri(uri)
599        .header("Content-Type", "application/json")
600        .header("Authorization", format!("Bearer {}", api_key))
601        .body(body)
602        .map(|request| client.send(request));
603
604    async move {
605        let mut response = request?.await?;
606        let mut body = String::new();
607        response.body_mut().read_to_string(&mut body).await?;
608
609        if response.status().is_success() {
610            let response: OpenAiEmbeddingResponse =
611                serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
612            Ok(response)
613        } else {
614            Err(anyhow!(
615                "error during embedding, status: {:?}, body: {:?}",
616                response.status(),
617                body
618            ))
619        }
620    }
621}
622
623pub async fn extract_tool_args_from_events(
624    tool_name: String,
625    mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
626) -> Result<impl Send + Stream<Item = Result<String>>> {
627    let mut tool_use_index = None;
628    let mut first_chunk = None;
629    while let Some(event) = events.next().await {
630        let call = event?.choices.into_iter().find_map(|choice| {
631            choice.delta.tool_calls?.into_iter().find_map(|call| {
632                if call.function.as_ref()?.name.as_deref()? == tool_name {
633                    Some(call)
634                } else {
635                    None
636                }
637            })
638        });
639        if let Some(call) = call {
640            tool_use_index = Some(call.index);
641            first_chunk = call.function.and_then(|func| func.arguments);
642            break;
643        }
644    }
645
646    let Some(tool_use_index) = tool_use_index else {
647        return Err(anyhow!("tool not used"));
648    };
649
650    Ok(events.filter_map(move |event| {
651        let result = match event {
652            Err(error) => Some(Err(error)),
653            Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
654                choice.delta.tool_calls?.into_iter().find_map(|call| {
655                    if call.index == tool_use_index {
656                        let func = call.function?;
657                        let mut arguments = func.arguments?;
658                        if let Some(mut first_chunk) = first_chunk.take() {
659                            first_chunk.push_str(&arguments);
660                            arguments = first_chunk
661                        }
662                        Some(Ok(arguments))
663                    } else {
664                        None
665                    }
666                })
667            }),
668        };
669
670        async move { result }
671    }))
672}
673
674pub fn extract_text_from_events(
675    response: impl Stream<Item = Result<ResponseStreamEvent>>,
676) -> impl Stream<Item = Result<String>> {
677    response.filter_map(|response| async move {
678        match response {
679            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
680            Err(error) => Some(Err(error)),
681        }
682    })
683}