embedding_queue.rs

  1use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
  2
  3use gpui::AppContext;
  4use parking_lot::Mutex;
  5use smol::channel;
  6
  7use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
  8
  9#[derive(Clone)]
 10pub struct FileToEmbed {
 11    pub worktree_id: i64,
 12    pub path: PathBuf,
 13    pub mtime: SystemTime,
 14    pub documents: Vec<Document>,
 15    pub job_handle: JobHandle,
 16}
 17
 18impl std::fmt::Debug for FileToEmbed {
 19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 20        f.debug_struct("FileToEmbed")
 21            .field("worktree_id", &self.worktree_id)
 22            .field("path", &self.path)
 23            .field("mtime", &self.mtime)
 24            .field("document", &self.documents)
 25            .finish_non_exhaustive()
 26    }
 27}
 28
 29impl PartialEq for FileToEmbed {
 30    fn eq(&self, other: &Self) -> bool {
 31        self.worktree_id == other.worktree_id
 32            && self.path == other.path
 33            && self.mtime == other.mtime
 34            && self.documents == other.documents
 35    }
 36}
 37
 38pub struct EmbeddingQueue {
 39    embedding_provider: Arc<dyn EmbeddingProvider>,
 40    pending_batch: Vec<FileToEmbedFragment>,
 41    pending_batch_token_count: usize,
 42    finished_files_tx: channel::Sender<FileToEmbed>,
 43    finished_files_rx: channel::Receiver<FileToEmbed>,
 44}
 45
 46pub struct FileToEmbedFragment {
 47    file: Arc<Mutex<FileToEmbed>>,
 48    document_range: Range<usize>,
 49}
 50
 51impl EmbeddingQueue {
 52    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
 53        let (finished_files_tx, finished_files_rx) = channel::unbounded();
 54        Self {
 55            embedding_provider,
 56            pending_batch: Vec::new(),
 57            pending_batch_token_count: 0,
 58            finished_files_tx,
 59            finished_files_rx,
 60        }
 61    }
 62
 63    pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
 64        let file = Arc::new(Mutex::new(file));
 65
 66        self.pending_batch.push(FileToEmbedFragment {
 67            file: file.clone(),
 68            document_range: 0..0,
 69        });
 70
 71        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
 72        for (ix, document) in file.lock().documents.iter().enumerate() {
 73            let next_token_count = self.pending_batch_token_count + document.token_count;
 74            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
 75                let range_end = fragment_range.end;
 76                self.flush(cx);
 77                self.pending_batch.push(FileToEmbedFragment {
 78                    file: file.clone(),
 79                    document_range: range_end..range_end,
 80                });
 81                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
 82            }
 83
 84            fragment_range.end = ix + 1;
 85            self.pending_batch_token_count += document.token_count;
 86        }
 87    }
 88
 89    pub fn flush(&mut self, cx: &mut AppContext) {
 90        let batch = mem::take(&mut self.pending_batch);
 91        self.pending_batch_token_count = 0;
 92        if batch.is_empty() {
 93            return;
 94        }
 95
 96        let finished_files_tx = self.finished_files_tx.clone();
 97        let embedding_provider = self.embedding_provider.clone();
 98        cx.background().spawn(async move {
 99            let mut spans = Vec::new();
100            for fragment in &batch {
101                let file = fragment.file.lock();
102                spans.extend(
103                    file.documents[fragment.document_range.clone()]
104                        .iter()
105                        .map(|d| d.content.clone()),
106                );
107            }
108
109            match embedding_provider.embed_batch(spans).await {
110                Ok(embeddings) => {
111                    let mut embeddings = embeddings.into_iter();
112                    for fragment in batch {
113                        for document in
114                            &mut fragment.file.lock().documents[fragment.document_range.clone()]
115                        {
116                            if let Some(embedding) = embeddings.next() {
117                                document.embedding = embedding;
118                            } else {
119                                //
120                                log::error!("number of embeddings returned different from number of documents");
121                            }
122                        }
123
124                        if let Some(file) = Arc::into_inner(fragment.file) {
125                            finished_files_tx.try_send(file.into_inner()).unwrap();
126                        }
127                    }
128                }
129                Err(error) => {
130                    log::error!("{:?}", error);
131                }
132            }
133        })
134        .detach();
135    }
136
137    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
138        self.finished_files_rx.clone()
139    }
140}