1use crate::{parsing::Span, JobHandle};
2use ai::embedding::EmbeddingProvider;
3use gpui::executor::Background;
4use parking_lot::Mutex;
5use smol::channel;
6use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
7
8#[derive(Clone)]
9pub struct FileToEmbed {
10 pub worktree_id: i64,
11 pub path: Arc<Path>,
12 pub mtime: SystemTime,
13 pub spans: Vec<Span>,
14 pub job_handle: JobHandle,
15}
16
17impl std::fmt::Debug for FileToEmbed {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("FileToEmbed")
20 .field("worktree_id", &self.worktree_id)
21 .field("path", &self.path)
22 .field("mtime", &self.mtime)
23 .field("spans", &self.spans)
24 .finish_non_exhaustive()
25 }
26}
27
28impl PartialEq for FileToEmbed {
29 fn eq(&self, other: &Self) -> bool {
30 self.worktree_id == other.worktree_id
31 && self.path == other.path
32 && self.mtime == other.mtime
33 && self.spans == other.spans
34 }
35}
36
37pub struct EmbeddingQueue {
38 embedding_provider: Arc<dyn EmbeddingProvider>,
39 pending_batch: Vec<FileFragmentToEmbed>,
40 executor: Arc<Background>,
41 pending_batch_token_count: usize,
42 finished_files_tx: channel::Sender<FileToEmbed>,
43 finished_files_rx: channel::Receiver<FileToEmbed>,
44 api_key: Option<String>,
45}
46
47#[derive(Clone)]
48pub struct FileFragmentToEmbed {
49 file: Arc<Mutex<FileToEmbed>>,
50 span_range: Range<usize>,
51}
52
53impl EmbeddingQueue {
54 pub fn new(
55 embedding_provider: Arc<dyn EmbeddingProvider>,
56 executor: Arc<Background>,
57 api_key: Option<String>,
58 ) -> Self {
59 let (finished_files_tx, finished_files_rx) = channel::unbounded();
60 Self {
61 embedding_provider,
62 executor,
63 pending_batch: Vec::new(),
64 pending_batch_token_count: 0,
65 finished_files_tx,
66 finished_files_rx,
67 api_key,
68 }
69 }
70
71 pub fn set_api_key(&mut self, api_key: Option<String>) {
72 self.api_key = api_key
73 }
74
75 pub fn push(&mut self, file: FileToEmbed) {
76 if file.spans.is_empty() {
77 self.finished_files_tx.try_send(file).unwrap();
78 return;
79 }
80
81 let file = Arc::new(Mutex::new(file));
82
83 self.pending_batch.push(FileFragmentToEmbed {
84 file: file.clone(),
85 span_range: 0..0,
86 });
87
88 let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
89 for (ix, span) in file.lock().spans.iter().enumerate() {
90 let span_token_count = if span.embedding.is_none() {
91 span.token_count
92 } else {
93 0
94 };
95
96 let next_token_count = self.pending_batch_token_count + span_token_count;
97 if next_token_count > self.embedding_provider.max_tokens_per_batch() {
98 let range_end = fragment_range.end;
99 self.flush();
100 self.pending_batch.push(FileFragmentToEmbed {
101 file: file.clone(),
102 span_range: range_end..range_end,
103 });
104 fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
105 }
106
107 fragment_range.end = ix + 1;
108 self.pending_batch_token_count += span_token_count;
109 }
110 }
111
112 pub fn flush(&mut self) {
113 let batch = mem::take(&mut self.pending_batch);
114 self.pending_batch_token_count = 0;
115 if batch.is_empty() {
116 return;
117 }
118
119 let finished_files_tx = self.finished_files_tx.clone();
120 let embedding_provider = self.embedding_provider.clone();
121 let api_key = self.api_key.clone();
122
123 self.executor
124 .spawn(async move {
125 let mut spans = Vec::new();
126 for fragment in &batch {
127 let file = fragment.file.lock();
128 spans.extend(
129 file.spans[fragment.span_range.clone()]
130 .iter()
131 .filter(|d| d.embedding.is_none())
132 .map(|d| d.content.clone()),
133 );
134 }
135
136 // If spans is 0, just send the fragment to the finished files if its the last one.
137 if spans.is_empty() {
138 for fragment in batch.clone() {
139 if let Some(file) = Arc::into_inner(fragment.file) {
140 finished_files_tx.try_send(file.into_inner()).unwrap();
141 }
142 }
143 return;
144 };
145
146 match embedding_provider.embed_batch(spans, api_key).await {
147 Ok(embeddings) => {
148 let mut embeddings = embeddings.into_iter();
149 for fragment in batch {
150 for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
151 .iter_mut()
152 .filter(|d| d.embedding.is_none())
153 {
154 if let Some(embedding) = embeddings.next() {
155 span.embedding = Some(embedding);
156 } else {
157 log::error!("number of embeddings != number of documents");
158 }
159 }
160
161 if let Some(file) = Arc::into_inner(fragment.file) {
162 finished_files_tx.try_send(file.into_inner()).unwrap();
163 }
164 }
165 }
166 Err(error) => {
167 log::error!("{:?}", error);
168 }
169 }
170 })
171 .detach();
172 }
173
174 pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
175 self.finished_files_rx.clone()
176 }
177}