open_ai.rs

  1pub mod batches;
  2pub mod completion;
  3pub mod responses;
  4
  5use anyhow::{Context as _, Result, anyhow};
  6use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  7use http_client::{
  8    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
  9    http::{HeaderMap, HeaderValue},
 10};
 11pub use language_model_core::ReasoningEffort;
 12use serde::{Deserialize, Serialize};
 13use serde_json::Value;
 14use std::{convert::TryFrom, future::Future};
 15use strum::EnumIter;
 16use thiserror::Error;
 17
/// Base URL of the public OpenAI v1 API.
///
/// The request functions in this module append endpoint paths such as `/chat/completions`
/// and `/embeddings` to the `api_url` they are given.
pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 19
 20fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 21    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 22}
 23
 24#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 25#[serde(rename_all = "lowercase")]
 26pub enum Role {
 27    User,
 28    Assistant,
 29    System,
 30    Tool,
 31}
 32
 33impl TryFrom<String> for Role {
 34    type Error = anyhow::Error;
 35
 36    fn try_from(value: String) -> Result<Self> {
 37        match value.as_str() {
 38            "user" => Ok(Self::User),
 39            "assistant" => Ok(Self::Assistant),
 40            "system" => Ok(Self::System),
 41            "tool" => Ok(Self::Tool),
 42            _ => anyhow::bail!("invalid role '{value}'"),
 43        }
 44    }
 45}
 46
 47impl From<Role> for String {
 48    fn from(val: Role) -> Self {
 49        match val {
 50            Role::User => "user".to_owned(),
 51            Role::Assistant => "assistant".to_owned(),
 52            Role::System => "system".to_owned(),
 53            Role::Tool => "tool".to_owned(),
 54        }
 55    }
 56}
 57
/// An OpenAI model that requests can target.
///
/// Built-in variants (de)serialize as their OpenAI model id via `#[serde(rename)]`, and
/// `Model::default()` is `FiveMini` (`gpt-5-mini`). `Custom` is serialized under the
/// `"custom"` variant name and describes any other model, with its limits and capabilities
/// given explicitly in its fields.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[serde(rename = "gpt-3.5-turbo")]
    ThreePointFiveTurbo,
    #[serde(rename = "gpt-4")]
    Four,
    #[serde(rename = "gpt-4-turbo")]
    FourTurbo,
    #[serde(rename = "gpt-4o-mini")]
    FourOmniMini,
    #[serde(rename = "gpt-4.1-nano")]
    FourPointOneNano,
    #[serde(rename = "o1")]
    O1,
    #[serde(rename = "o3-mini")]
    O3Mini,
    #[serde(rename = "o3")]
    O3,
    #[serde(rename = "o4-mini")]
    O4Mini,
    #[serde(rename = "gpt-5")]
    Five,
    #[serde(rename = "gpt-5-codex")]
    FiveCodex,
    // `#[default]`: this variant is what `Model::default()` returns.
    #[serde(rename = "gpt-5-mini")]
    #[default]
    FiveMini,
    #[serde(rename = "gpt-5-nano")]
    FiveNano,
    #[serde(rename = "gpt-5.1")]
    FivePointOne,
    #[serde(rename = "gpt-5.2")]
    FivePointTwo,
    #[serde(rename = "gpt-5.2-codex")]
    FivePointTwoCodex,
    #[serde(rename = "gpt-5.3-codex")]
    FivePointThreeCodex,
    #[serde(rename = "gpt-5.4")]
    FivePointFour,
    #[serde(rename = "gpt-5.4-pro")]
    FivePointFourPro,
    #[serde(rename = "custom")]
    Custom {
        /// Model id; this is what `Model::id` returns for a custom model.
        name: String,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        /// When `None`, `Model::display_name` falls back to `name`.
        display_name: Option<String>,
        /// Token limit returned by `Model::max_token_count`.
        max_tokens: u64,
        /// Output-token limit returned by `Model::max_output_tokens`.
        max_output_tokens: Option<u64>,
        /// Not read by any code in this file. NOTE(review): presumably forwarded as the
        /// request's `max_completion_tokens` by callers — confirm.
        max_completion_tokens: Option<u64>,
        /// Reasoning effort returned by `Model::reasoning_effort`.
        reasoning_effort: Option<ReasoningEffort>,
        /// Returned by `Model::supports_chat_completions`; defaults to `true` when the field
        /// is missing from deserialized input.
        #[serde(default = "default_supports_chat_completions")]
        supports_chat_completions: bool,
    },
}

/// Serde default for `Model::Custom::supports_chat_completions`: a custom model is assumed
/// to support chat completions unless its configuration says otherwise.
const fn default_supports_chat_completions() -> bool {
    true
}
117
118impl Model {
119    pub fn default_fast() -> Self {
120        Self::FiveMini
121    }
122
123    pub fn from_id(id: &str) -> Result<Self> {
124        match id {
125            "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
126            "gpt-4" => Ok(Self::Four),
127            "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
128            "gpt-4o-mini" => Ok(Self::FourOmniMini),
129            "gpt-4.1-nano" => Ok(Self::FourPointOneNano),
130            "o1" => Ok(Self::O1),
131            "o3-mini" => Ok(Self::O3Mini),
132            "o3" => Ok(Self::O3),
133            "o4-mini" => Ok(Self::O4Mini),
134            "gpt-5" => Ok(Self::Five),
135            "gpt-5-codex" => Ok(Self::FiveCodex),
136            "gpt-5-mini" => Ok(Self::FiveMini),
137            "gpt-5-nano" => Ok(Self::FiveNano),
138            "gpt-5.1" => Ok(Self::FivePointOne),
139            "gpt-5.2" => Ok(Self::FivePointTwo),
140            "gpt-5.2-codex" => Ok(Self::FivePointTwoCodex),
141            "gpt-5.3-codex" => Ok(Self::FivePointThreeCodex),
142            "gpt-5.4" => Ok(Self::FivePointFour),
143            "gpt-5.4-pro" => Ok(Self::FivePointFourPro),
144            invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
145        }
146    }
147
148    pub fn id(&self) -> &str {
149        match self {
150            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
151            Self::Four => "gpt-4",
152            Self::FourTurbo => "gpt-4-turbo",
153            Self::FourOmniMini => "gpt-4o-mini",
154            Self::FourPointOneNano => "gpt-4.1-nano",
155            Self::O1 => "o1",
156            Self::O3Mini => "o3-mini",
157            Self::O3 => "o3",
158            Self::O4Mini => "o4-mini",
159            Self::Five => "gpt-5",
160            Self::FiveCodex => "gpt-5-codex",
161            Self::FiveMini => "gpt-5-mini",
162            Self::FiveNano => "gpt-5-nano",
163            Self::FivePointOne => "gpt-5.1",
164            Self::FivePointTwo => "gpt-5.2",
165            Self::FivePointTwoCodex => "gpt-5.2-codex",
166            Self::FivePointThreeCodex => "gpt-5.3-codex",
167            Self::FivePointFour => "gpt-5.4",
168            Self::FivePointFourPro => "gpt-5.4-pro",
169            Self::Custom { name, .. } => name,
170        }
171    }
172
173    pub fn display_name(&self) -> &str {
174        match self {
175            Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
176            Self::Four => "gpt-4",
177            Self::FourTurbo => "gpt-4-turbo",
178            Self::FourOmniMini => "gpt-4o-mini",
179            Self::FourPointOneNano => "gpt-4.1-nano",
180            Self::O1 => "o1",
181            Self::O3Mini => "o3-mini",
182            Self::O3 => "o3",
183            Self::O4Mini => "o4-mini",
184            Self::Five => "gpt-5",
185            Self::FiveCodex => "gpt-5-codex",
186            Self::FiveMini => "gpt-5-mini",
187            Self::FiveNano => "gpt-5-nano",
188            Self::FivePointOne => "gpt-5.1",
189            Self::FivePointTwo => "gpt-5.2",
190            Self::FivePointTwoCodex => "gpt-5.2-codex",
191            Self::FivePointThreeCodex => "gpt-5.3-codex",
192            Self::FivePointFour => "gpt-5.4",
193            Self::FivePointFourPro => "gpt-5.4-pro",
194            Self::Custom { display_name, .. } => display_name.as_deref().unwrap_or(&self.id()),
195        }
196    }
197
198    pub fn max_token_count(&self) -> u64 {
199        match self {
200            Self::ThreePointFiveTurbo => 16_385,
201            Self::Four => 8_192,
202            Self::FourTurbo => 128_000,
203            Self::FourOmniMini => 128_000,
204            Self::FourPointOneNano => 1_047_576,
205            Self::O1 => 200_000,
206            Self::O3Mini => 200_000,
207            Self::O3 => 200_000,
208            Self::O4Mini => 200_000,
209            Self::Five => 272_000,
210            Self::FiveCodex => 272_000,
211            Self::FiveMini => 400_000,
212            Self::FiveNano => 400_000,
213            Self::FivePointOne => 400_000,
214            Self::FivePointTwo => 400_000,
215            Self::FivePointTwoCodex => 400_000,
216            Self::FivePointThreeCodex => 400_000,
217            Self::FivePointFour => 1_050_000,
218            Self::FivePointFourPro => 1_050_000,
219            Self::Custom { max_tokens, .. } => *max_tokens,
220        }
221    }
222
223    pub fn max_output_tokens(&self) -> Option<u64> {
224        match self {
225            Self::Custom {
226                max_output_tokens, ..
227            } => *max_output_tokens,
228            Self::ThreePointFiveTurbo => Some(4_096),
229            Self::Four => Some(8_192),
230            Self::FourTurbo => Some(4_096),
231            Self::FourOmniMini => Some(16_384),
232            Self::FourPointOneNano => Some(32_768),
233            Self::O1 => Some(100_000),
234            Self::O3Mini => Some(100_000),
235            Self::O3 => Some(100_000),
236            Self::O4Mini => Some(100_000),
237            Self::Five => Some(128_000),
238            Self::FiveCodex => Some(128_000),
239            Self::FiveMini => Some(128_000),
240            Self::FiveNano => Some(128_000),
241            Self::FivePointOne => Some(128_000),
242            Self::FivePointTwo => Some(128_000),
243            Self::FivePointTwoCodex => Some(128_000),
244            Self::FivePointThreeCodex => Some(128_000),
245            Self::FivePointFour => Some(128_000),
246            Self::FivePointFourPro => Some(128_000),
247        }
248    }
249
250    pub fn reasoning_effort(&self) -> Option<ReasoningEffort> {
251        match self {
252            Self::Custom {
253                reasoning_effort, ..
254            } => reasoning_effort.to_owned(),
255            Self::O4Mini | Self::FivePointThreeCodex | Self::FivePointFourPro => {
256                Some(ReasoningEffort::Medium)
257            }
258            _ => None,
259        }
260    }
261
262    pub fn supports_chat_completions(&self) -> bool {
263        match self {
264            Self::Custom {
265                supports_chat_completions,
266                ..
267            } => *supports_chat_completions,
268            Self::O4Mini
269            | Self::FiveCodex
270            | Self::FivePointTwoCodex
271            | Self::FivePointThreeCodex
272            | Self::FivePointFourPro => false,
273            _ => true,
274        }
275    }
276
277    /// Returns whether the given model supports the `parallel_tool_calls` parameter.
278    ///
279    /// If the model does not support the parameter, do not pass it up, or the API will return an error.
280    pub fn supports_parallel_tool_calls(&self) -> bool {
281        match self {
282            Self::ThreePointFiveTurbo
283            | Self::Four
284            | Self::FourTurbo
285            | Self::FourOmniMini
286            | Self::FourPointOneNano
287            | Self::Five
288            | Self::FiveCodex
289            | Self::FiveMini
290            | Self::FivePointOne
291            | Self::FivePointTwo
292            | Self::FivePointTwoCodex
293            | Self::FivePointThreeCodex
294            | Self::FivePointFour
295            | Self::FivePointFourPro
296            | Self::FiveNano => true,
297            Self::O1 | Self::O3 | Self::O3Mini | Self::O4Mini | Model::Custom { .. } => false,
298        }
299    }
300
301    /// Returns whether the given model supports the `prompt_cache_key` parameter.
302    ///
303    /// If the model does not support the parameter, do not pass it up.
304    pub fn supports_prompt_cache_key(&self) -> bool {
305        true
306    }
307}
308
309#[derive(Debug, Serialize, Deserialize)]
310pub struct StreamOptions {
311    pub include_usage: bool,
312}
313
314impl Default for StreamOptions {
315    fn default() -> Self {
316        Self {
317            include_usage: true,
318        }
319    }
320}
321
/// JSON body of a `/chat/completions` request.
///
/// `None` options and empty `stop`/`tools` lists are left out of the serialized JSON, so
/// only parameters that were actually set are sent; the same fields default when missing
/// on deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id.
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<RequestMessage>,
    /// Whether a streamed response is requested.
    pub stream: bool,
    /// Streaming options; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    /// Maximum number of completion tokens; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Tool selection policy; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether to enable parallel function calling during tool use.
    ///
    /// Leave `None` for models where `Model::supports_parallel_tool_calls` is `false`; the
    /// API returns an error if those models receive this parameter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools offered to the model; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Prompt cache key; leave `None` when `Model::supports_prompt_cache_key` is `false`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Reasoning effort; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<ReasoningEffort>,
}
347
/// The `tool_choice` request parameter.
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"required"` and `"none"`;
/// `Other` is untagged and serializes as the wrapped tool definition itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool offered to the model, internally tagged by `"type"` (currently only `"function"`).
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow` is not visible here; confirm it is still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// A function the model may call.
///
/// `description` and `parameters` have no `skip_serializing_if`, so `None` serializes as
/// JSON `null`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Parameters as raw JSON (presumably a JSON Schema object — confirm with callers).
    pub parameters: Option<Value>,
}
371
/// A chat message, internally tagged by `"role"` (`"assistant"`, `"user"`, `"system"`,
/// `"tool"`).
///
/// Also used as the message type of a non-streaming [`Choice`].
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Model output. `content` may be absent; `tool_calls` is omitted when empty and
    /// defaults to empty when missing.
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// A tool result. `tool_call_id` presumably matches the `ToolCall::id` being answered —
    /// confirm with callers.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
391
/// Message content: either a plain string or a list of parts.
///
/// Untagged, so it serializes as a bare JSON string (`Plain`) or a JSON array (`Multipart`).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
398
399impl MessageContent {
400    pub fn empty() -> Self {
401        MessageContent::Multipart(vec![])
402    }
403
404    pub fn push_part(&mut self, part: MessagePart) {
405        match self {
406            MessageContent::Plain(text) => {
407                *self =
408                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
409            }
410            MessageContent::Multipart(parts) if parts.is_empty() => match part {
411                MessagePart::Text { text } => *self = MessageContent::Plain(text),
412                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
413            },
414            MessageContent::Multipart(parts) => parts.push(part),
415        }
416    }
417}
418
419impl From<Vec<MessagePart>> for MessageContent {
420    fn from(mut parts: Vec<MessagePart>) -> Self {
421        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
422            MessageContent::Plain(std::mem::take(text))
423        } else {
424            MessageContent::Multipart(parts)
425        }
426    }
427}
428
/// One part of multipart content, internally tagged by `"type"` (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
#[serde(tag = "type")]
pub enum MessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// The image reference of an image part.
#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
pub struct ImageUrl {
    /// Image URL (possibly a data URL — not constrained here).
    pub url: String,
    /// Optional detail level; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

/// A tool call made by the model.
///
/// `content` is flattened, so its `"type"` tag and `"function"` field appear next to `"id"`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The body of a tool call, internally tagged by `"type"` (currently only `"function"`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// The function invoked by a tool call.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a string (presumably JSON-encoded — confirm with producers).
    pub arguments: String,
}
463
/// A non-streaming chat-completion response, as decoded by `non_streaming_completion`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    /// Creation time as an integer (presumably a Unix timestamp in seconds — confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One candidate completion in a [`Response`].
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    /// Why generation stopped, if reported.
    pub finish_reason: Option<String>,
}
480
/// Incremental message data carried by one streamed [`ChoiceDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    /// Next fragment of message text, if any.
    pub content: Option<String>,
    /// Tool-call fragments; omitted from serialized JSON when absent or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Reasoning text, when the server sends it; omitted from serialized JSON when absent or
    /// empty. NOTE(review): appears to be a provider extension rather than core OpenAI — confirm.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub reasoning_content: Option<String>,
}

/// A fragment of a streamed tool call; fragments sharing an `index` belong to the same call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,

    // There is also an optional `type` field that would determine if a
    // function is there. Sometimes this streams in with the `function` before
    // it streams in the `type`
    pub function: Option<FunctionChunk>,
}

/// A fragment of a streamed function call; `name` and `arguments` may arrive piecemeal.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
507
/// Token accounting reported by the API.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: Option<ResponseMessageDelta>,
    /// Why generation stopped, once reported.
    pub finish_reason: Option<String>,
}
521
/// Error returned by the chat-completion request functions.
#[derive(Error, Debug)]
pub enum RequestError {
    /// The server answered with a non-success status. The response body and headers are
    /// kept so callers can inspect them (e.g. `Retry-After` when converting to
    /// `LanguageModelCompletionError`).
    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
    HttpResponseError {
        provider: String,
        status_code: StatusCode,
        body: String,
        headers: HeaderMap<HeaderValue>,
    },
    /// Any other failure (serialization, request building, transport, body read, decoding).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
534
/// Error payload that can appear in place of an event in a completion stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamError {
    message: String,
}

/// One decoded `data:` payload of a completion stream.
///
/// Untagged: decoding tries `Ok` (a regular event) first and falls back to an
/// `{"error": {...}}` object.
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
    Ok(ResponseStreamEvent),
    Err { error: ResponseStreamError },
}

/// One event of a streamed chat completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, present only on events that report it.
    pub usage: Option<Usage>,
}
552
553pub async fn non_streaming_completion(
554    client: &dyn HttpClient,
555    api_url: &str,
556    api_key: &str,
557    request: Request,
558) -> Result<Response, RequestError> {
559    let uri = format!("{api_url}/chat/completions");
560    let request_builder = HttpRequest::builder()
561        .method(Method::POST)
562        .uri(uri)
563        .header("Content-Type", "application/json")
564        .header("Authorization", format!("Bearer {}", api_key.trim()));
565
566    let request = request_builder
567        .body(AsyncBody::from(
568            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
569        ))
570        .map_err(|e| RequestError::Other(e.into()))?;
571
572    let mut response = client.send(request).await?;
573    if response.status().is_success() {
574        let mut body = String::new();
575        response
576            .body_mut()
577            .read_to_string(&mut body)
578            .await
579            .map_err(|e| RequestError::Other(e.into()))?;
580
581        serde_json::from_str(&body).map_err(|e| RequestError::Other(e.into()))
582    } else {
583        let mut body = String::new();
584        response
585            .body_mut()
586            .read_to_string(&mut body)
587            .await
588            .map_err(|e| RequestError::Other(e.into()))?;
589
590        Err(RequestError::HttpResponseError {
591            provider: "openai".to_owned(),
592            status_code: response.status(),
593            body,
594            headers: response.headers().clone(),
595        })
596    }
597}
598
599pub async fn stream_completion(
600    client: &dyn HttpClient,
601    provider_name: &str,
602    api_url: &str,
603    api_key: &str,
604    request: Request,
605) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
606    let uri = format!("{api_url}/chat/completions");
607    let request_builder = HttpRequest::builder()
608        .method(Method::POST)
609        .uri(uri)
610        .header("Content-Type", "application/json")
611        .header("Authorization", format!("Bearer {}", api_key.trim()));
612
613    let request = request_builder
614        .body(AsyncBody::from(
615            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
616        ))
617        .map_err(|e| RequestError::Other(e.into()))?;
618
619    let mut response = client.send(request).await?;
620    if response.status().is_success() {
621        let reader = BufReader::new(response.into_body());
622        Ok(reader
623            .lines()
624            .filter_map(|line| async move {
625                match line {
626                    Ok(line) => {
627                        let line = line.strip_prefix("data: ").or_else(|| line.strip_prefix("data:"))?;
628                        if line == "[DONE]" {
629                            None
630                        } else {
631                            match serde_json::from_str(line) {
632                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
633                                Ok(ResponseStreamResult::Err { error }) => {
634                                    Some(Err(anyhow!(error.message)))
635                                }
636                                Err(error) => {
637                                    log::error!(
638                                        "Failed to parse OpenAI response into ResponseStreamResult: `{}`\n\
639                                        Response: `{}`",
640                                        error,
641                                        line,
642                                    );
643                                    Some(Err(anyhow!(error)))
644                                }
645                            }
646                        }
647                    }
648                    Err(error) => Some(Err(anyhow!(error))),
649                }
650            })
651            .boxed())
652    } else {
653        let mut body = String::new();
654        response
655            .body_mut()
656            .read_to_string(&mut body)
657            .await
658            .map_err(|e| RequestError::Other(e.into()))?;
659
660        Err(RequestError::HttpResponseError {
661            provider: provider_name.to_owned(),
662            status_code: response.status(),
663            body,
664            headers: response.headers().clone(),
665        })
666    }
667}
668
/// An OpenAI embedding model, serialized as its model id.
#[derive(Copy, Clone, Serialize, Deserialize)]
pub enum OpenAiEmbeddingModel {
    #[serde(rename = "text-embedding-3-small")]
    TextEmbedding3Small,
    #[serde(rename = "text-embedding-3-large")]
    TextEmbedding3Large,
}

/// JSON body of an `/embeddings` request; borrows the input texts.
#[derive(Serialize)]
struct OpenAiEmbeddingRequest<'a> {
    model: OpenAiEmbeddingModel,
    input: Vec<&'a str>,
}

/// Decoded `/embeddings` response.
#[derive(Deserialize)]
pub struct OpenAiEmbeddingResponse {
    /// One entry per returned embedding (presumably in input order — confirm against the API).
    pub data: Vec<OpenAiEmbedding>,
}

/// A single embedding vector.
#[derive(Deserialize)]
pub struct OpenAiEmbedding {
    pub embedding: Vec<f32>,
}
692
693pub fn embed<'a>(
694    client: &dyn HttpClient,
695    api_url: &str,
696    api_key: &str,
697    model: OpenAiEmbeddingModel,
698    texts: impl IntoIterator<Item = &'a str>,
699) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
700    let uri = format!("{api_url}/embeddings");
701
702    let request = OpenAiEmbeddingRequest {
703        model,
704        input: texts.into_iter().collect(),
705    };
706    let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
707    let request = HttpRequest::builder()
708        .method(Method::POST)
709        .uri(uri)
710        .header("Content-Type", "application/json")
711        .header("Authorization", format!("Bearer {}", api_key.trim()))
712        .body(body)
713        .map(|request| client.send(request));
714
715    async move {
716        let mut response = request?.await?;
717        let mut body = String::new();
718        response.body_mut().read_to_string(&mut body).await?;
719
720        anyhow::ensure!(
721            response.status().is_success(),
722            "error during embedding, status: {:?}, body: {:?}",
723            response.status(),
724            body
725        );
726        let response: OpenAiEmbeddingResponse =
727            serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
728        Ok(response)
729    }
730}
731
732// -- Conversions to `language_model_core` types --
733
734impl From<RequestError> for language_model_core::LanguageModelCompletionError {
735    fn from(error: RequestError) -> Self {
736        match error {
737            RequestError::HttpResponseError {
738                provider,
739                status_code,
740                body,
741                headers,
742            } => {
743                let retry_after = headers
744                    .get(http_client::http::header::RETRY_AFTER)
745                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
746                    .map(std::time::Duration::from_secs);
747
748                Self::from_http_status(provider.into(), status_code, body, retry_after)
749            }
750            RequestError::Other(e) => Self::Other(e),
751        }
752    }
753}