//! cloud.rs — cloud-backed implementation of `EmbeddingProvider`.

 1use crate::{Embedding, EmbeddingProvider, TextToEmbed};
 2use anyhow::{Context as _, Result, anyhow};
 3use client::{Client, proto};
 4use collections::HashMap;
 5use futures::{FutureExt, future::BoxFuture};
 6use std::sync::Arc;
 7
/// An `EmbeddingProvider` that computes embeddings remotely over the cloud
/// RPC client, using a fixed model identifier for every request.
pub struct CloudEmbeddingProvider {
    // Model identifier sent with each RPC (e.g. "openai/text-embedding-3-small").
    model: String,
    // RPC client used to fetch cached embeddings and compute missing ones.
    client: Arc<Client>,
}
12
13impl CloudEmbeddingProvider {
14    pub fn new(client: Arc<Client>) -> Self {
15        Self {
16            model: "openai/text-embedding-3-small".into(),
17            client,
18        }
19    }
20}
21
22impl EmbeddingProvider for CloudEmbeddingProvider {
23    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
24        // First, fetch any embeddings that are cached based on the requested texts' digests
25        // Then compute any embeddings that are missing.
26        async move {
27            if !self.client.status().borrow().is_connected() {
28                return Err(anyhow!("sign in required"));
29            }
30
31            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
32                model: self.model.clone(),
33                digests: texts
34                    .iter()
35                    .map(|to_embed| to_embed.digest.to_vec())
36                    .collect(),
37            });
38            let mut embeddings = cached_embeddings
39                .await
40                .context("failed to fetch cached embeddings via cloud model")?
41                .embeddings
42                .into_iter()
43                .map(|embedding| {
44                    let digest: [u8; 32] = embedding
45                        .digest
46                        .try_into()
47                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
48                    Ok((digest, embedding.dimensions))
49                })
50                .collect::<Result<HashMap<_, _>>>()?;
51
52            let compute_embeddings_request = proto::ComputeEmbeddings {
53                model: self.model.clone(),
54                texts: texts
55                    .iter()
56                    .filter_map(|to_embed| {
57                        if embeddings.contains_key(&to_embed.digest) {
58                            None
59                        } else {
60                            Some(to_embed.text.to_string())
61                        }
62                    })
63                    .collect(),
64            };
65            if !compute_embeddings_request.texts.is_empty() {
66                let missing_embeddings = self.client.request(compute_embeddings_request).await?;
67                for embedding in missing_embeddings.embeddings {
68                    let digest: [u8; 32] = embedding
69                        .digest
70                        .try_into()
71                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
72                    embeddings.insert(digest, embedding.dimensions);
73                }
74            }
75
76            texts
77                .iter()
78                .map(|to_embed| {
79                    let embedding =
80                        embeddings.get(&to_embed.digest).cloned().with_context(|| {
81                            format!("server did not return an embedding for {:?}", to_embed)
82                        })?;
83                    Ok(Embedding::new(embedding))
84                })
85                .collect()
86        }
87        .boxed()
88    }
89
90    fn batch_size(&self) -> usize {
91        2048
92    }
93}