embedding_queue.rs

  1use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
  2use gpui::executor::Background;
  3use parking_lot::Mutex;
  4use smol::channel;
  5use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
  6
  7#[derive(Clone)]
  8pub struct FileToEmbed {
  9    pub worktree_id: i64,
 10    pub path: PathBuf,
 11    pub mtime: SystemTime,
 12    pub documents: Vec<Document>,
 13    pub job_handle: JobHandle,
 14}
 15
 16impl std::fmt::Debug for FileToEmbed {
 17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 18        f.debug_struct("FileToEmbed")
 19            .field("worktree_id", &self.worktree_id)
 20            .field("path", &self.path)
 21            .field("mtime", &self.mtime)
 22            .field("document", &self.documents)
 23            .finish_non_exhaustive()
 24    }
 25}
 26
 27impl PartialEq for FileToEmbed {
 28    fn eq(&self, other: &Self) -> bool {
 29        self.worktree_id == other.worktree_id
 30            && self.path == other.path
 31            && self.mtime == other.mtime
 32            && self.documents == other.documents
 33    }
 34}
 35
 36pub struct EmbeddingQueue {
 37    embedding_provider: Arc<dyn EmbeddingProvider>,
 38    pending_batch: Vec<FileToEmbedFragment>,
 39    executor: Arc<Background>,
 40    pending_batch_token_count: usize,
 41    finished_files_tx: channel::Sender<FileToEmbed>,
 42    finished_files_rx: channel::Receiver<FileToEmbed>,
 43}
 44
 45pub struct FileToEmbedFragment {
 46    file: Arc<Mutex<FileToEmbed>>,
 47    document_range: Range<usize>,
 48}
 49
 50impl EmbeddingQueue {
 51    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
 52        let (finished_files_tx, finished_files_rx) = channel::unbounded();
 53        Self {
 54            embedding_provider,
 55            executor,
 56            pending_batch: Vec::new(),
 57            pending_batch_token_count: 0,
 58            finished_files_tx,
 59            finished_files_rx,
 60        }
 61    }
 62
 63    pub fn push(&mut self, file: FileToEmbed) {
 64        if file.documents.is_empty() {
 65            self.finished_files_tx.try_send(file).unwrap();
 66            return;
 67        }
 68
 69        let file = Arc::new(Mutex::new(file));
 70
 71        self.pending_batch.push(FileToEmbedFragment {
 72            file: file.clone(),
 73            document_range: 0..0,
 74        });
 75
 76        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
 77        for (ix, document) in file.lock().documents.iter().enumerate() {
 78            let next_token_count = self.pending_batch_token_count + document.token_count;
 79            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
 80                let range_end = fragment_range.end;
 81                self.flush();
 82                self.pending_batch.push(FileToEmbedFragment {
 83                    file: file.clone(),
 84                    document_range: range_end..range_end,
 85                });
 86                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
 87            }
 88
 89            fragment_range.end = ix + 1;
 90            self.pending_batch_token_count += document.token_count;
 91        }
 92    }
 93
 94    pub fn flush(&mut self) {
 95        let batch = mem::take(&mut self.pending_batch);
 96        self.pending_batch_token_count = 0;
 97        if batch.is_empty() {
 98            return;
 99        }
100
101        let finished_files_tx = self.finished_files_tx.clone();
102        let embedding_provider = self.embedding_provider.clone();
103        self.executor.spawn(async move {
104            let mut spans = Vec::new();
105            for fragment in &batch {
106                let file = fragment.file.lock();
107                spans.extend(
108                    {
109                        file.documents[fragment.document_range.clone()]
110                            .iter()
111                            .map(|d| d.content.clone())
112                        }
113                );
114            }
115
116            match embedding_provider.embed_batch(spans).await {
117                Ok(embeddings) => {
118                    let mut embeddings = embeddings.into_iter();
119                    for fragment in batch {
120                        for document in
121                            &mut fragment.file.lock().documents[fragment.document_range.clone()]
122                        {
123                            if let Some(embedding) = embeddings.next() {
124                                document.embedding = embedding;
125                            } else {
126                                //
127                                log::error!("number of embeddings returned different from number of documents");
128                            }
129                        }
130
131                        if let Some(file) = Arc::into_inner(fragment.file) {
132                            finished_files_tx.try_send(file.into_inner()).unwrap();
133                        }
134                    }
135                }
136                Err(error) => {
137                    log::error!("{:?}", error);
138                }
139            }
140        })
141        .detach();
142    }
143
144    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
145        self.finished_files_rx.clone()
146    }
147}