1use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
2
3use gpui::AppContext;
4use parking_lot::Mutex;
5use smol::channel;
6
7use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
8
9#[derive(Clone)]
10pub struct FileToEmbed {
11 pub worktree_id: i64,
12 pub path: PathBuf,
13 pub mtime: SystemTime,
14 pub documents: Vec<Document>,
15 pub job_handle: JobHandle,
16}
17
18impl std::fmt::Debug for FileToEmbed {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("FileToEmbed")
21 .field("worktree_id", &self.worktree_id)
22 .field("path", &self.path)
23 .field("mtime", &self.mtime)
24 .field("document", &self.documents)
25 .finish_non_exhaustive()
26 }
27}
28
29impl PartialEq for FileToEmbed {
30 fn eq(&self, other: &Self) -> bool {
31 self.worktree_id == other.worktree_id
32 && self.path == other.path
33 && self.mtime == other.mtime
34 && self.documents == other.documents
35 }
36}
37
38pub struct EmbeddingQueue {
39 embedding_provider: Arc<dyn EmbeddingProvider>,
40 pending_batch: Vec<FileToEmbedFragment>,
41 pending_batch_token_count: usize,
42 finished_files_tx: channel::Sender<FileToEmbed>,
43 finished_files_rx: channel::Receiver<FileToEmbed>,
44}
45
46pub struct FileToEmbedFragment {
47 file: Arc<Mutex<FileToEmbed>>,
48 document_range: Range<usize>,
49}
50
51impl EmbeddingQueue {
52 pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
53 let (finished_files_tx, finished_files_rx) = channel::unbounded();
54 Self {
55 embedding_provider,
56 pending_batch: Vec::new(),
57 pending_batch_token_count: 0,
58 finished_files_tx,
59 finished_files_rx,
60 }
61 }
62
63 pub fn push(&mut self, file: FileToEmbed, cx: &mut AppContext) {
64 let file = Arc::new(Mutex::new(file));
65
66 self.pending_batch.push(FileToEmbedFragment {
67 file: file.clone(),
68 document_range: 0..0,
69 });
70
71 let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
72 for (ix, document) in file.lock().documents.iter().enumerate() {
73 let next_token_count = self.pending_batch_token_count + document.token_count;
74 if next_token_count > self.embedding_provider.max_tokens_per_batch() {
75 let range_end = fragment_range.end;
76 self.flush(cx);
77 self.pending_batch.push(FileToEmbedFragment {
78 file: file.clone(),
79 document_range: range_end..range_end,
80 });
81 fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
82 }
83
84 fragment_range.end = ix + 1;
85 self.pending_batch_token_count += document.token_count;
86 }
87 }
88
89 pub fn flush(&mut self, cx: &mut AppContext) {
90 let batch = mem::take(&mut self.pending_batch);
91 self.pending_batch_token_count = 0;
92 if batch.is_empty() {
93 return;
94 }
95
96 let finished_files_tx = self.finished_files_tx.clone();
97 let embedding_provider = self.embedding_provider.clone();
98 cx.background().spawn(async move {
99 let mut spans = Vec::new();
100 for fragment in &batch {
101 let file = fragment.file.lock();
102 spans.extend(
103 file.documents[fragment.document_range.clone()]
104 .iter()
105 .map(|d| d.content.clone()),
106 );
107 }
108
109 match embedding_provider.embed_batch(spans).await {
110 Ok(embeddings) => {
111 let mut embeddings = embeddings.into_iter();
112 for fragment in batch {
113 for document in
114 &mut fragment.file.lock().documents[fragment.document_range.clone()]
115 {
116 if let Some(embedding) = embeddings.next() {
117 document.embedding = embedding;
118 } else {
119 //
120 log::error!("number of embeddings returned different from number of documents");
121 }
122 }
123
124 if let Some(file) = Arc::into_inner(fragment.file) {
125 finished_files_tx.try_send(file.into_inner()).unwrap();
126 }
127 }
128 }
129 Err(error) => {
130 log::error!("{:?}", error);
131 }
132 }
133 })
134 .detach();
135 }
136
137 pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
138 self.finished_files_rx.clone()
139 }
140}