1use crate::{Embedding, EmbeddingProvider, TextToEmbed};
2use anyhow::{Context as _, Result, anyhow};
3use client::{Client, proto};
4use collections::HashMap;
5use futures::{FutureExt, future::BoxFuture};
6use std::sync::Arc;
7
8pub struct CloudEmbeddingProvider {
9 model: String,
10 client: Arc<Client>,
11}
12
13impl CloudEmbeddingProvider {
14 pub fn new(client: Arc<Client>) -> Self {
15 Self {
16 model: "openai/text-embedding-3-small".into(),
17 client,
18 }
19 }
20}
21
22impl EmbeddingProvider for CloudEmbeddingProvider {
23 fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
24 // First, fetch any embeddings that are cached based on the requested texts' digests
25 // Then compute any embeddings that are missing.
26 async move {
27 if !self.client.status().borrow().is_connected() {
28 return Err(anyhow!("sign in required"));
29 }
30
31 let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
32 model: self.model.clone(),
33 digests: texts
34 .iter()
35 .map(|to_embed| to_embed.digest.to_vec())
36 .collect(),
37 });
38 let mut embeddings = cached_embeddings
39 .await
40 .context("failed to fetch cached embeddings via cloud model")?
41 .embeddings
42 .into_iter()
43 .map(|embedding| {
44 let digest: [u8; 32] = embedding
45 .digest
46 .try_into()
47 .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
48 Ok((digest, embedding.dimensions))
49 })
50 .collect::<Result<HashMap<_, _>>>()?;
51
52 let compute_embeddings_request = proto::ComputeEmbeddings {
53 model: self.model.clone(),
54 texts: texts
55 .iter()
56 .filter_map(|to_embed| {
57 if embeddings.contains_key(&to_embed.digest) {
58 None
59 } else {
60 Some(to_embed.text.to_string())
61 }
62 })
63 .collect(),
64 };
65 if !compute_embeddings_request.texts.is_empty() {
66 let missing_embeddings = self.client.request(compute_embeddings_request).await?;
67 for embedding in missing_embeddings.embeddings {
68 let digest: [u8; 32] = embedding
69 .digest
70 .try_into()
71 .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
72 embeddings.insert(digest, embedding.dimensions);
73 }
74 }
75
76 texts
77 .iter()
78 .map(|to_embed| {
79 let embedding =
80 embeddings.get(&to_embed.digest).cloned().with_context(|| {
81 format!("server did not return an embedding for {:?}", to_embed)
82 })?;
83 Ok(Embedding::new(embedding))
84 })
85 .collect()
86 }
87 .boxed()
88 }
89
90 fn batch_size(&self) -> usize {
91 2048
92 }
93}