1use crate::{Embedding, EmbeddingProvider, TextToEmbed};
2use anyhow::{anyhow, Context, Result};
3use client::{proto, Client};
4use collections::HashMap;
5use futures::{future::BoxFuture, FutureExt};
6use std::sync::Arc;
7
8pub struct CloudEmbeddingProvider {
9 model: String,
10 client: Arc<Client>,
11}
12
13impl CloudEmbeddingProvider {
14 pub fn new(client: Arc<Client>) -> Self {
15 Self {
16 model: "openai/text-embedding-3-small".into(),
17 client,
18 }
19 }
20}
21
22impl EmbeddingProvider for CloudEmbeddingProvider {
23 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
24 // First, fetch any embeddings that are cached based on the requested texts' digests
25 // Then compute any embeddings that are missing.
26 async move {
27 let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
28 model: self.model.clone(),
29 digests: texts
30 .iter()
31 .map(|to_embed| to_embed.digest.to_vec())
32 .collect(),
33 });
34 let mut embeddings = cached_embeddings
35 .await
36 .context("failed to fetch cached embeddings via cloud model")?
37 .embeddings
38 .into_iter()
39 .map(|embedding| {
40 let digest: [u8; 32] = embedding
41 .digest
42 .try_into()
43 .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
44 Ok((digest, embedding.dimensions))
45 })
46 .collect::<Result<HashMap<_, _>>>()?;
47
48 let compute_embeddings_request = proto::ComputeEmbeddings {
49 model: self.model.clone(),
50 texts: texts
51 .iter()
52 .filter_map(|to_embed| {
53 if embeddings.contains_key(&to_embed.digest) {
54 None
55 } else {
56 Some(to_embed.text.to_string())
57 }
58 })
59 .collect(),
60 };
61 if !compute_embeddings_request.texts.is_empty() {
62 let missing_embeddings = self.client.request(compute_embeddings_request).await?;
63 for embedding in missing_embeddings.embeddings {
64 let digest: [u8; 32] = embedding
65 .digest
66 .try_into()
67 .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
68 embeddings.insert(digest, embedding.dimensions);
69 }
70 }
71
72 texts
73 .iter()
74 .map(|to_embed| {
75 let embedding =
76 embeddings.get(&to_embed.digest).cloned().with_context(|| {
77 format!("server did not return an embedding for {:?}", to_embed)
78 })?;
79 Ok(Embedding::new(embedding))
80 })
81 .collect()
82 }
83 .boxed()
84 }
85
86 fn batch_size(&self) -> usize {
87 2048
88 }
89}