open_ai.rs

 1use crate::{Embedding, EmbeddingProvider, TextToEmbed};
 2use anyhow::Result;
 3use futures::{FutureExt, future::BoxFuture};
 4use http_client::HttpClient;
 5pub use open_ai::OpenAiEmbeddingModel;
 6use std::sync::Arc;
 7
 8pub struct OpenAiEmbeddingProvider {
 9    client: Arc<dyn HttpClient>,
10    model: OpenAiEmbeddingModel,
11    api_url: String,
12    api_key: String,
13}
14
15impl OpenAiEmbeddingProvider {
16    pub fn new(
17        client: Arc<dyn HttpClient>,
18        model: OpenAiEmbeddingModel,
19        api_url: String,
20        api_key: String,
21    ) -> Self {
22        Self {
23            client,
24            model,
25            api_url,
26            api_key,
27        }
28    }
29}
30
31impl EmbeddingProvider for OpenAiEmbeddingProvider {
32    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
33        let embed = open_ai::embed(
34            self.client.as_ref(),
35            &self.api_url,
36            &self.api_key,
37            self.model,
38            texts.iter().map(|to_embed| to_embed.text),
39        );
40        async move {
41            let response = embed.await?;
42            Ok(response
43                .data
44                .into_iter()
45                .map(|data| Embedding::new(data.embedding))
46                .collect())
47        }
48        .boxed()
49    }
50
51    fn batch_size(&self) -> usize {
52        // From https://platform.openai.com/docs/api-reference/embeddings/create
53        2048
54    }
55}