embedding_queue.rs

  1use crate::{parsing::Span, JobHandle};
  2use ai::embedding::EmbeddingProvider;
  3use gpui::executor::Background;
  4use parking_lot::Mutex;
  5use smol::channel;
  6use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
  7
  8#[derive(Clone)]
  9pub struct FileToEmbed {
 10    pub worktree_id: i64,
 11    pub path: Arc<Path>,
 12    pub mtime: SystemTime,
 13    pub spans: Vec<Span>,
 14    pub job_handle: JobHandle,
 15}
 16
 17impl std::fmt::Debug for FileToEmbed {
 18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 19        f.debug_struct("FileToEmbed")
 20            .field("worktree_id", &self.worktree_id)
 21            .field("path", &self.path)
 22            .field("mtime", &self.mtime)
 23            .field("spans", &self.spans)
 24            .finish_non_exhaustive()
 25    }
 26}
 27
 28impl PartialEq for FileToEmbed {
 29    fn eq(&self, other: &Self) -> bool {
 30        self.worktree_id == other.worktree_id
 31            && self.path == other.path
 32            && self.mtime == other.mtime
 33            && self.spans == other.spans
 34    }
 35}
 36
 37pub struct EmbeddingQueue {
 38    embedding_provider: Arc<dyn EmbeddingProvider>,
 39    pending_batch: Vec<FileFragmentToEmbed>,
 40    executor: Arc<Background>,
 41    pending_batch_token_count: usize,
 42    finished_files_tx: channel::Sender<FileToEmbed>,
 43    finished_files_rx: channel::Receiver<FileToEmbed>,
 44    api_key: Option<String>,
 45}
 46
 47#[derive(Clone)]
 48pub struct FileFragmentToEmbed {
 49    file: Arc<Mutex<FileToEmbed>>,
 50    span_range: Range<usize>,
 51}
 52
 53impl EmbeddingQueue {
 54    pub fn new(
 55        embedding_provider: Arc<dyn EmbeddingProvider>,
 56        executor: Arc<Background>,
 57        api_key: Option<String>,
 58    ) -> Self {
 59        let (finished_files_tx, finished_files_rx) = channel::unbounded();
 60        Self {
 61            embedding_provider,
 62            executor,
 63            pending_batch: Vec::new(),
 64            pending_batch_token_count: 0,
 65            finished_files_tx,
 66            finished_files_rx,
 67            api_key,
 68        }
 69    }
 70
 71    pub fn set_api_key(&mut self, api_key: Option<String>) {
 72        self.api_key = api_key
 73    }
 74
 75    pub fn push(&mut self, file: FileToEmbed) {
 76        if file.spans.is_empty() {
 77            self.finished_files_tx.try_send(file).unwrap();
 78            return;
 79        }
 80
 81        let file = Arc::new(Mutex::new(file));
 82
 83        self.pending_batch.push(FileFragmentToEmbed {
 84            file: file.clone(),
 85            span_range: 0..0,
 86        });
 87
 88        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
 89        for (ix, span) in file.lock().spans.iter().enumerate() {
 90            let span_token_count = if span.embedding.is_none() {
 91                span.token_count
 92            } else {
 93                0
 94            };
 95
 96            let next_token_count = self.pending_batch_token_count + span_token_count;
 97            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
 98                let range_end = fragment_range.end;
 99                self.flush();
100                self.pending_batch.push(FileFragmentToEmbed {
101                    file: file.clone(),
102                    span_range: range_end..range_end,
103                });
104                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
105            }
106
107            fragment_range.end = ix + 1;
108            self.pending_batch_token_count += span_token_count;
109        }
110    }
111
112    pub fn flush(&mut self) {
113        let batch = mem::take(&mut self.pending_batch);
114        self.pending_batch_token_count = 0;
115        if batch.is_empty() {
116            return;
117        }
118
119        let finished_files_tx = self.finished_files_tx.clone();
120        let embedding_provider = self.embedding_provider.clone();
121        let api_key = self.api_key.clone();
122
123        self.executor
124            .spawn(async move {
125                let mut spans = Vec::new();
126                for fragment in &batch {
127                    let file = fragment.file.lock();
128                    spans.extend(
129                        file.spans[fragment.span_range.clone()]
130                            .iter()
131                            .filter(|d| d.embedding.is_none())
132                            .map(|d| d.content.clone()),
133                    );
134                }
135
136                // If spans is 0, just send the fragment to the finished files if its the last one.
137                if spans.is_empty() {
138                    for fragment in batch.clone() {
139                        if let Some(file) = Arc::into_inner(fragment.file) {
140                            finished_files_tx.try_send(file.into_inner()).unwrap();
141                        }
142                    }
143                    return;
144                };
145
146                match embedding_provider.embed_batch(spans, api_key).await {
147                    Ok(embeddings) => {
148                        let mut embeddings = embeddings.into_iter();
149                        for fragment in batch {
150                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
151                                .iter_mut()
152                                .filter(|d| d.embedding.is_none())
153                            {
154                                if let Some(embedding) = embeddings.next() {
155                                    span.embedding = Some(embedding);
156                                } else {
157                                    log::error!("number of embeddings != number of documents");
158                                }
159                            }
160
161                            if let Some(file) = Arc::into_inner(fragment.file) {
162                                finished_files_tx.try_send(file.into_inner()).unwrap();
163                            }
164                        }
165                    }
166                    Err(error) => {
167                        log::error!("{:?}", error);
168                    }
169                }
170            })
171            .detach();
172    }
173
174    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
175        self.finished_files_rx.clone()
176    }
177}