embedding_queue.rs

  1use crate::{parsing::Span, JobHandle};
  2use ai::embedding::EmbeddingProvider;
  3use gpui::BackgroundExecutor;
  4use parking_lot::Mutex;
  5use smol::channel;
  6use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
  7
  8#[derive(Clone)]
  9pub struct FileToEmbed {
 10    pub worktree_id: i64,
 11    pub path: Arc<Path>,
 12    pub mtime: SystemTime,
 13    pub spans: Vec<Span>,
 14    pub job_handle: JobHandle,
 15}
 16
 17impl std::fmt::Debug for FileToEmbed {
 18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 19        f.debug_struct("FileToEmbed")
 20            .field("worktree_id", &self.worktree_id)
 21            .field("path", &self.path)
 22            .field("mtime", &self.mtime)
 23            .field("spans", &self.spans)
 24            .finish_non_exhaustive()
 25    }
 26}
 27
 28impl PartialEq for FileToEmbed {
 29    fn eq(&self, other: &Self) -> bool {
 30        self.worktree_id == other.worktree_id
 31            && self.path == other.path
 32            && self.mtime == other.mtime
 33            && self.spans == other.spans
 34    }
 35}
 36
 37pub struct EmbeddingQueue {
 38    embedding_provider: Arc<dyn EmbeddingProvider>,
 39    pending_batch: Vec<FileFragmentToEmbed>,
 40    executor: BackgroundExecutor,
 41    pending_batch_token_count: usize,
 42    finished_files_tx: channel::Sender<FileToEmbed>,
 43    finished_files_rx: channel::Receiver<FileToEmbed>,
 44}
 45
 46#[derive(Clone)]
 47pub struct FileFragmentToEmbed {
 48    file: Arc<Mutex<FileToEmbed>>,
 49    span_range: Range<usize>,
 50}
 51
 52impl EmbeddingQueue {
 53    pub fn new(
 54        embedding_provider: Arc<dyn EmbeddingProvider>,
 55        executor: BackgroundExecutor,
 56    ) -> Self {
 57        let (finished_files_tx, finished_files_rx) = channel::unbounded();
 58        Self {
 59            embedding_provider,
 60            executor,
 61            pending_batch: Vec::new(),
 62            pending_batch_token_count: 0,
 63            finished_files_tx,
 64            finished_files_rx,
 65        }
 66    }
 67
 68    pub fn push(&mut self, file: FileToEmbed) {
 69        if file.spans.is_empty() {
 70            self.finished_files_tx.try_send(file).unwrap();
 71            return;
 72        }
 73
 74        let file = Arc::new(Mutex::new(file));
 75
 76        self.pending_batch.push(FileFragmentToEmbed {
 77            file: file.clone(),
 78            span_range: 0..0,
 79        });
 80
 81        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
 82        for (ix, span) in file.lock().spans.iter().enumerate() {
 83            let span_token_count = if span.embedding.is_none() {
 84                span.token_count
 85            } else {
 86                0
 87            };
 88
 89            let next_token_count = self.pending_batch_token_count + span_token_count;
 90            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
 91                let range_end = fragment_range.end;
 92                self.flush();
 93                self.pending_batch.push(FileFragmentToEmbed {
 94                    file: file.clone(),
 95                    span_range: range_end..range_end,
 96                });
 97                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
 98            }
 99
100            fragment_range.end = ix + 1;
101            self.pending_batch_token_count += span_token_count;
102        }
103    }
104
105    pub fn flush(&mut self) {
106        let batch = mem::take(&mut self.pending_batch);
107        self.pending_batch_token_count = 0;
108        if batch.is_empty() {
109            return;
110        }
111
112        let finished_files_tx = self.finished_files_tx.clone();
113        let embedding_provider = self.embedding_provider.clone();
114
115        self.executor
116            .spawn(async move {
117                let mut spans = Vec::new();
118                for fragment in &batch {
119                    let file = fragment.file.lock();
120                    spans.extend(
121                        file.spans[fragment.span_range.clone()]
122                            .iter()
123                            .filter(|d| d.embedding.is_none())
124                            .map(|d| d.content.clone()),
125                    );
126                }
127
128                // If spans is 0, just send the fragment to the finished files if its the last one.
129                if spans.is_empty() {
130                    for fragment in batch.clone() {
131                        if let Some(file) = Arc::into_inner(fragment.file) {
132                            finished_files_tx.try_send(file.into_inner()).unwrap();
133                        }
134                    }
135                    return;
136                };
137
138                match embedding_provider.embed_batch(spans).await {
139                    Ok(embeddings) => {
140                        let mut embeddings = embeddings.into_iter();
141                        for fragment in batch {
142                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
143                                .iter_mut()
144                                .filter(|d| d.embedding.is_none())
145                            {
146                                if let Some(embedding) = embeddings.next() {
147                                    span.embedding = Some(embedding);
148                                } else {
149                                    log::error!("number of embeddings != number of documents");
150                                }
151                            }
152
153                            if let Some(file) = Arc::into_inner(fragment.file) {
154                                finished_files_tx.try_send(file.into_inner()).unwrap();
155                            }
156                        }
157                    }
158                    Err(error) => {
159                        log::error!("{:?}", error);
160                    }
161                }
162            })
163            .detach();
164    }
165
166    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
167        self.finished_files_rx.clone()
168    }
169}