1use crate::{Embedding, EmbeddingProvider, TextToEmbed};
2use anyhow::Result;
3use futures::{FutureExt, future::BoxFuture};
4use http_client::HttpClient;
5pub use open_ai::OpenAiEmbeddingModel;
6use std::sync::Arc;
7
8pub struct OpenAiEmbeddingProvider {
9 client: Arc<dyn HttpClient>,
10 model: OpenAiEmbeddingModel,
11 api_url: String,
12 api_key: String,
13}
14
15impl OpenAiEmbeddingProvider {
16 pub fn new(
17 client: Arc<dyn HttpClient>,
18 model: OpenAiEmbeddingModel,
19 api_url: String,
20 api_key: String,
21 ) -> Self {
22 Self {
23 client,
24 model,
25 api_url,
26 api_key,
27 }
28 }
29}
30
31impl EmbeddingProvider for OpenAiEmbeddingProvider {
32 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
33 let embed = open_ai::embed(
34 self.client.as_ref(),
35 &self.api_url,
36 &self.api_key,
37 self.model,
38 texts.iter().map(|to_embed| to_embed.text),
39 );
40 async move {
41 let response = embed.await?;
42 Ok(response
43 .data
44 .into_iter()
45 .map(|data| Embedding::new(data.embedding))
46 .collect())
47 }
48 .boxed()
49 }
50
51 fn batch_size(&self) -> usize {
52 // From https://platform.openai.com/docs/api-reference/embeddings/create
53 2048
54 }
55}