embedding_queue.rs

  1use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
  2use gpui::executor::Background;
  3use parking_lot::Mutex;
  4use smol::channel;
  5use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
  6
/// A file whose spans are handed to the embedding queue.
///
/// The hand-written `Debug` and `PartialEq` impls in this file cover every
/// field except `job_handle`.
#[derive(Clone)]
pub struct FileToEmbed {
    // NOTE(review): presumably the id of the worktree that contains the file — confirm against callers.
    pub worktree_id: i64,
    // NOTE(review): presumably the file's path within that worktree — confirm against callers.
    pub path: Arc<Path>,
    // NOTE(review): presumably the file's modification time when it was parsed — confirm against callers.
    pub mtime: SystemTime,
    /// Spans of the file. The queue fills in `embedding` on every span where
    /// it is still `None`; spans that already carry an embedding are left as is.
    pub spans: Vec<Span>,
    // Opaque handle defined elsewhere in the crate; the queue only moves it along with the file.
    // NOTE(review): looks like it tracks an outstanding indexing job — confirm.
    pub job_handle: JobHandle,
}
 15
 16impl std::fmt::Debug for FileToEmbed {
 17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 18        f.debug_struct("FileToEmbed")
 19            .field("worktree_id", &self.worktree_id)
 20            .field("path", &self.path)
 21            .field("mtime", &self.mtime)
 22            .field("spans", &self.spans)
 23            .finish_non_exhaustive()
 24    }
 25}
 26
 27impl PartialEq for FileToEmbed {
 28    fn eq(&self, other: &Self) -> bool {
 29        self.worktree_id == other.worktree_id
 30            && self.path == other.path
 31            && self.mtime == other.mtime
 32            && self.spans == other.spans
 33    }
 34}
 35
/// Groups the spans of pushed files into batches bounded by the provider's
/// token limit, embeds each batch on a background task, and hands completed
/// files out through a channel.
pub struct EmbeddingQueue {
    /// Supplies the per-batch token limit and performs the embedding calls.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Fragments collected since the last flush.
    pending_batch: Vec<FileFragmentToEmbed>,
    /// Executor on which `flush` spawns the embedding task.
    executor: Arc<Background>,
    /// Sum of `token_count` over the not-yet-embedded spans in `pending_batch`;
    /// reset to zero by `flush`.
    pending_batch_token_count: usize,
    /// Sending half of the unbounded channel that carries finished files.
    finished_files_tx: channel::Sender<FileToEmbed>,
    /// Receiving half of that channel; `finished_files` hands out clones of it.
    finished_files_rx: channel::Receiver<FileToEmbed>,
}
 44
/// A contiguous run of spans from one file, assigned to a single batch.
#[derive(Clone)]
pub struct FileFragmentToEmbed {
    /// The file the spans belong to, shared with any other fragments of the same file.
    file: Arc<Mutex<FileToEmbed>>,
    /// Index range into `file.spans` covered by this fragment; may be empty.
    span_range: Range<usize>,
}
 50
 51impl EmbeddingQueue {
 52    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
 53        let (finished_files_tx, finished_files_rx) = channel::unbounded();
 54        Self {
 55            embedding_provider,
 56            executor,
 57            pending_batch: Vec::new(),
 58            pending_batch_token_count: 0,
 59            finished_files_tx,
 60            finished_files_rx,
 61        }
 62    }
 63
 64    pub fn push(&mut self, file: FileToEmbed) {
 65        if file.spans.is_empty() {
 66            self.finished_files_tx.try_send(file).unwrap();
 67            return;
 68        }
 69
 70        let file = Arc::new(Mutex::new(file));
 71
 72        self.pending_batch.push(FileFragmentToEmbed {
 73            file: file.clone(),
 74            span_range: 0..0,
 75        });
 76
 77        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
 78        for (ix, span) in file.lock().spans.iter().enumerate() {
 79            let span_token_count = if span.embedding.is_none() {
 80                span.token_count
 81            } else {
 82                0
 83            };
 84
 85            let next_token_count = self.pending_batch_token_count + span_token_count;
 86            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
 87                let range_end = fragment_range.end;
 88                self.flush();
 89                self.pending_batch.push(FileFragmentToEmbed {
 90                    file: file.clone(),
 91                    span_range: range_end..range_end,
 92                });
 93                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
 94            }
 95
 96            fragment_range.end = ix + 1;
 97            self.pending_batch_token_count += span_token_count;
 98        }
 99    }
100
101    pub fn flush(&mut self) {
102        let batch = mem::take(&mut self.pending_batch);
103        self.pending_batch_token_count = 0;
104        if batch.is_empty() {
105            return;
106        }
107
108        let finished_files_tx = self.finished_files_tx.clone();
109        let embedding_provider = self.embedding_provider.clone();
110
111        self.executor
112            .spawn(async move {
113                let mut spans = Vec::new();
114                for fragment in &batch {
115                    let file = fragment.file.lock();
116                    spans.extend(
117                        file.spans[fragment.span_range.clone()]
118                            .iter()
119                            .filter(|d| d.embedding.is_none())
120                            .map(|d| d.content.clone()),
121                    );
122                }
123
124                // If spans is 0, just send the fragment to the finished files if its the last one.
125                if spans.is_empty() {
126                    for fragment in batch.clone() {
127                        if let Some(file) = Arc::into_inner(fragment.file) {
128                            finished_files_tx.try_send(file.into_inner()).unwrap();
129                        }
130                    }
131                    return;
132                };
133
134                match embedding_provider.embed_batch(spans).await {
135                    Ok(embeddings) => {
136                        let mut embeddings = embeddings.into_iter();
137                        for fragment in batch {
138                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
139                                .iter_mut()
140                                .filter(|d| d.embedding.is_none())
141                            {
142                                if let Some(embedding) = embeddings.next() {
143                                    span.embedding = Some(embedding);
144                                } else {
145                                    log::error!("number of embeddings != number of documents");
146                                }
147                            }
148
149                            if let Some(file) = Arc::into_inner(fragment.file) {
150                                finished_files_tx.try_send(file.into_inner()).unwrap();
151                            }
152                        }
153                    }
154                    Err(error) => {
155                        log::error!("{:?}", error);
156                    }
157                }
158            })
159            .detach();
160    }
161
162    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
163        self.finished_files_rx.clone()
164    }
165}