cloud.rs

 1use crate::{Embedding, EmbeddingProvider, TextToEmbed};
 2use anyhow::{anyhow, Context, Result};
 3use client::{proto, Client};
 4use collections::HashMap;
 5use futures::{future::BoxFuture, FutureExt};
 6use std::sync::Arc;
 7
 8pub struct CloudEmbeddingProvider {
 9    model: String,
10    client: Arc<Client>,
11}
12
13impl CloudEmbeddingProvider {
14    pub fn new(client: Arc<Client>) -> Self {
15        Self {
16            model: "openai/text-embedding-3-small".into(),
17            client,
18        }
19    }
20}
21
22impl EmbeddingProvider for CloudEmbeddingProvider {
23    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
24        // First, fetch any embeddings that are cached based on the requested texts' digests
25        // Then compute any embeddings that are missing.
26        async move {
27            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
28                model: self.model.clone(),
29                digests: texts
30                    .iter()
31                    .map(|to_embed| to_embed.digest.to_vec())
32                    .collect(),
33            });
34            let mut embeddings = cached_embeddings
35                .await
36                .context("failed to fetch cached embeddings via cloud model")?
37                .embeddings
38                .into_iter()
39                .map(|embedding| {
40                    let digest: [u8; 32] = embedding
41                        .digest
42                        .try_into()
43                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
44                    Ok((digest, embedding.dimensions))
45                })
46                .collect::<Result<HashMap<_, _>>>()?;
47
48            let compute_embeddings_request = proto::ComputeEmbeddings {
49                model: self.model.clone(),
50                texts: texts
51                    .iter()
52                    .filter_map(|to_embed| {
53                        if embeddings.contains_key(&to_embed.digest) {
54                            None
55                        } else {
56                            Some(to_embed.text.to_string())
57                        }
58                    })
59                    .collect(),
60            };
61            if !compute_embeddings_request.texts.is_empty() {
62                let missing_embeddings = self.client.request(compute_embeddings_request).await?;
63                for embedding in missing_embeddings.embeddings {
64                    let digest: [u8; 32] = embedding
65                        .digest
66                        .try_into()
67                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
68                    embeddings.insert(digest, embedding.dimensions);
69                }
70            }
71
72            texts
73                .iter()
74                .map(|to_embed| {
75                    let dimensions = embeddings.remove(&to_embed.digest).with_context(|| {
76                        format!("server did not return an embedding for {:?}", to_embed)
77                    })?;
78                    Ok(Embedding::new(dimensions))
79                })
80                .collect()
81        }
82        .boxed()
83    }
84
85    fn batch_size(&self) -> usize {
86        2048
87    }
88}