1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::future::join_all;
8use futures::{self, FutureExt};
9use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
10use language::Buffer;
11use language_model::LanguageModelImage;
12use project::{Project, ProjectItem, ProjectPath, Symbol};
13use prompt_store::UserPromptId;
14use ref_cast::RefCast as _;
15use text::{Anchor, OffsetRangeExt};
16use util::ResultExt as _;
17
18use crate::ThreadStore;
19use crate::context::{
20 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
21 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
22 SymbolContextHandle, ThreadContextHandle,
23};
24use crate::context_strip::SuggestedContext;
25use crate::thread::{Thread, ThreadId};
26
/// Holds the context entries (files, directories, symbols, selections, threads, rules,
/// fetched URLs, images) attached by the user. Entries are kept in an `IndexSet`, so an
/// equal entry is never stored twice.
pub struct ContextStore {
    /// Project used to resolve project paths and open buffers.
    project: WeakEntity<Project>,
    /// Optional store used to save threads once their detailed summary has been generated.
    thread_summary_tasks: Vec<Task<()>>,
    /// Counter that hands out a fresh `ContextId` for each context handle created here.
    next_context_id: ContextId,
    /// Current context entries, in insertion order.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of the threads currently present in `context_set`, kept for quick lookups.
    context_thread_ids: HashSet<ThreadId>,
}
35
36impl ContextStore {
37 pub fn new(
38 project: WeakEntity<Project>,
39 thread_store: Option<WeakEntity<ThreadStore>>,
40 ) -> Self {
41 Self {
42 project,
43 thread_store,
44 thread_summary_tasks: Vec::new(),
45 next_context_id: ContextId::zero(),
46 context_set: IndexSet::default(),
47 context_thread_ids: HashSet::default(),
48 }
49 }
50
51 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
52 self.context_set.iter().map(|entry| entry.as_ref())
53 }
54
55 pub fn clear(&mut self) {
56 self.context_set.clear();
57 self.context_thread_ids.clear();
58 }
59
60 pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
61 let existing_context = thread
62 .messages()
63 .flat_map(|message| {
64 message
65 .loaded_context
66 .contexts
67 .iter()
68 .map(|context| AgentContextKey(context.handle()))
69 })
70 .collect::<HashSet<_>>();
71 self.context_set
72 .iter()
73 .filter(|context| !existing_context.contains(context))
74 .map(|entry| entry.0.clone())
75 .collect::<Vec<_>>()
76 }
77
78 pub fn add_file_from_path(
79 &mut self,
80 project_path: ProjectPath,
81 remove_if_exists: bool,
82 cx: &mut Context<Self>,
83 ) -> Task<Result<()>> {
84 let Some(project) = self.project.upgrade() else {
85 return Task::ready(Err(anyhow!("failed to read project")));
86 };
87
88 cx.spawn(async move |this, cx| {
89 let open_buffer_task = project.update(cx, |project, cx| {
90 project.open_buffer(project_path.clone(), cx)
91 })?;
92 let buffer = open_buffer_task.await?;
93 this.update(cx, |this, cx| {
94 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
95 })
96 })
97 }
98
99 pub fn add_file_from_buffer(
100 &mut self,
101 project_path: &ProjectPath,
102 buffer: Entity<Buffer>,
103 remove_if_exists: bool,
104 cx: &mut Context<Self>,
105 ) {
106 let context_id = self.next_context_id.post_inc();
107 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
108
109 let already_included = if self.has_context(&context) {
110 if remove_if_exists {
111 self.remove_context(&context, cx);
112 }
113 true
114 } else {
115 self.path_included_in_directory(project_path, cx).is_some()
116 };
117
118 if !already_included {
119 self.insert_context(context, cx);
120 }
121 }
122
123 pub fn add_directory(
124 &mut self,
125 project_path: &ProjectPath,
126 remove_if_exists: bool,
127 cx: &mut Context<Self>,
128 ) -> Result<()> {
129 let Some(project) = self.project.upgrade() else {
130 return Err(anyhow!("failed to read project"));
131 };
132
133 let Some(entry_id) = project
134 .read(cx)
135 .entry_for_path(project_path, cx)
136 .map(|entry| entry.id)
137 else {
138 return Err(anyhow!("no entry found for directory context"));
139 };
140
141 let context_id = self.next_context_id.post_inc();
142 let context = AgentContextHandle::Directory(DirectoryContextHandle {
143 entry_id,
144 context_id,
145 });
146
147 if self.has_context(&context) {
148 if remove_if_exists {
149 self.remove_context(&context, cx);
150 }
151 } else if self.path_included_in_directory(project_path, cx).is_none() {
152 self.insert_context(context, cx);
153 }
154
155 anyhow::Ok(())
156 }
157
158 pub fn add_symbol(
159 &mut self,
160 buffer: Entity<Buffer>,
161 symbol: SharedString,
162 range: Range<Anchor>,
163 enclosing_range: Range<Anchor>,
164 remove_if_exists: bool,
165 cx: &mut Context<Self>,
166 ) -> bool {
167 let context_id = self.next_context_id.post_inc();
168 let context = AgentContextHandle::Symbol(SymbolContextHandle {
169 buffer,
170 symbol,
171 range,
172 enclosing_range,
173 context_id,
174 });
175
176 if self.has_context(&context) {
177 if remove_if_exists {
178 self.remove_context(&context, cx);
179 }
180 return false;
181 }
182
183 self.insert_context(context, cx)
184 }
185
186 pub fn add_thread(
187 &mut self,
188 thread: Entity<Thread>,
189 remove_if_exists: bool,
190 cx: &mut Context<Self>,
191 ) {
192 let context_id = self.next_context_id.post_inc();
193 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
194
195 if self.has_context(&context) {
196 if remove_if_exists {
197 self.remove_context(&context, cx);
198 }
199 } else {
200 self.insert_context(context, cx);
201 }
202 }
203
204 fn start_summarizing_thread_if_needed(
205 &mut self,
206 thread: &Entity<Thread>,
207 cx: &mut Context<Self>,
208 ) {
209 if let Some(summary_task) =
210 thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
211 {
212 let thread = thread.clone();
213 let thread_store = self.thread_store.clone();
214
215 self.thread_summary_tasks.push(cx.spawn(async move |_, cx| {
216 summary_task.await;
217
218 if let Some(thread_store) = thread_store {
219 // Save thread so its summary can be reused later
220 let save_task = thread_store
221 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx));
222
223 if let Some(save_task) = save_task.ok() {
224 save_task.await.log_err();
225 }
226 }
227 }));
228 }
229 }
230
231 pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
232 let tasks = std::mem::take(&mut self.thread_summary_tasks);
233
234 cx.spawn(async move |_cx| {
235 join_all(tasks).await;
236 })
237 }
238
239 pub fn add_rules(
240 &mut self,
241 prompt_id: UserPromptId,
242 remove_if_exists: bool,
243 cx: &mut Context<ContextStore>,
244 ) {
245 let context_id = self.next_context_id.post_inc();
246 let context = AgentContextHandle::Rules(RulesContextHandle {
247 prompt_id,
248 context_id,
249 });
250
251 if self.has_context(&context) {
252 if remove_if_exists {
253 self.remove_context(&context, cx);
254 }
255 } else {
256 self.insert_context(context, cx);
257 }
258 }
259
260 pub fn add_fetched_url(
261 &mut self,
262 url: String,
263 text: impl Into<SharedString>,
264 cx: &mut Context<ContextStore>,
265 ) {
266 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
267 url: url.into(),
268 text: text.into(),
269 context_id: self.next_context_id.post_inc(),
270 });
271
272 self.insert_context(context, cx);
273 }
274
275 pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
276 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
277 let context = AgentContextHandle::Image(ImageContext {
278 original_image: image,
279 image_task,
280 context_id: self.next_context_id.post_inc(),
281 });
282 self.insert_context(context, cx);
283 }
284
285 pub fn add_selection(
286 &mut self,
287 buffer: Entity<Buffer>,
288 range: Range<Anchor>,
289 cx: &mut Context<ContextStore>,
290 ) {
291 let context_id = self.next_context_id.post_inc();
292 let context = AgentContextHandle::Selection(SelectionContextHandle {
293 buffer,
294 range,
295 context_id,
296 });
297 self.insert_context(context, cx);
298 }
299
300 pub fn add_suggested_context(
301 &mut self,
302 suggested: &SuggestedContext,
303 cx: &mut Context<ContextStore>,
304 ) {
305 match suggested {
306 SuggestedContext::File {
307 buffer,
308 icon_path: _,
309 name: _,
310 } => {
311 if let Some(buffer) = buffer.upgrade() {
312 let context_id = self.next_context_id.post_inc();
313 self.insert_context(
314 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
315 cx,
316 );
317 };
318 }
319 SuggestedContext::Thread { thread, name: _ } => {
320 if let Some(thread) = thread.upgrade() {
321 let context_id = self.next_context_id.post_inc();
322 self.insert_context(
323 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
324 cx,
325 );
326 }
327 }
328 }
329 }
330
331 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
332 match &context {
333 AgentContextHandle::Thread(thread_context) => {
334 self.context_thread_ids
335 .insert(thread_context.thread.read(cx).id().clone());
336 self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
337 }
338 _ => {}
339 }
340 let inserted = self.context_set.insert(AgentContextKey(context));
341 if inserted {
342 cx.notify();
343 }
344 inserted
345 }
346
347 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
348 if self
349 .context_set
350 .shift_remove(AgentContextKey::ref_cast(context))
351 {
352 match context {
353 AgentContextHandle::Thread(thread_context) => {
354 self.context_thread_ids
355 .remove(thread_context.thread.read(cx).id());
356 }
357 _ => {}
358 }
359 cx.notify();
360 }
361 }
362
363 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
364 self.context_set
365 .contains(AgentContextKey::ref_cast(context))
366 }
367
368 /// Returns whether this file path is already included directly in the context, or if it will be
369 /// included in the context via a directory.
370 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
371 let project = self.project.upgrade()?.read(cx);
372 self.context().find_map(|context| match context {
373 AgentContextHandle::File(file_context) => {
374 FileInclusion::check_file(file_context, path, cx)
375 }
376 AgentContextHandle::Directory(directory_context) => {
377 FileInclusion::check_directory(directory_context, path, project, cx)
378 }
379 _ => None,
380 })
381 }
382
383 pub fn path_included_in_directory(
384 &self,
385 path: &ProjectPath,
386 cx: &App,
387 ) -> Option<FileInclusion> {
388 let project = self.project.upgrade()?.read(cx);
389 self.context().find_map(|context| match context {
390 AgentContextHandle::Directory(directory_context) => {
391 FileInclusion::check_directory(directory_context, path, project, cx)
392 }
393 _ => None,
394 })
395 }
396
397 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
398 self.context().any(|context| match context {
399 AgentContextHandle::Symbol(context) => {
400 if context.symbol != symbol.name {
401 return false;
402 }
403 let buffer = context.buffer.read(cx);
404 let Some(context_path) = buffer.project_path(cx) else {
405 return false;
406 };
407 if context_path != symbol.path {
408 return false;
409 }
410 let context_range = context.range.to_point_utf16(&buffer.snapshot());
411 context_range.start == symbol.range.start.0
412 && context_range.end == symbol.range.end.0
413 }
414 _ => false,
415 })
416 }
417
418 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
419 self.context_thread_ids.contains(thread_id)
420 }
421
422 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
423 self.context_set
424 .contains(&RulesContextHandle::lookup_key(prompt_id))
425 }
426
427 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
428 self.context_set
429 .contains(&FetchedUrlContext::lookup_key(url.into()))
430 }
431
432 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
433 self.context()
434 .filter_map(|context| match context {
435 AgentContextHandle::File(file) => {
436 let buffer = file.buffer.read(cx);
437 buffer.project_path(cx)
438 }
439 AgentContextHandle::Directory(_)
440 | AgentContextHandle::Symbol(_)
441 | AgentContextHandle::Selection(_)
442 | AgentContextHandle::FetchedUrl(_)
443 | AgentContextHandle::Thread(_)
444 | AgentContextHandle::Rules(_)
445 | AgentContextHandle::Image(_) => None,
446 })
447 .collect()
448 }
449
450 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
451 &self.context_thread_ids
452 }
453}
454
/// Describes how a project path is covered by the current context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is a context entry (a file context, or the directory itself).
    Direct,
    /// The path lies beneath a directory context. `full_path` is that directory's path
    /// as returned by the worktree's `full_path`.
    InDirectory { full_path: PathBuf },
}
459
460impl FileInclusion {
461 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
462 let file_path = file_context.buffer.read(cx).project_path(cx)?;
463 if path == &file_path {
464 Some(FileInclusion::Direct)
465 } else {
466 None
467 }
468 }
469
470 fn check_directory(
471 directory_context: &DirectoryContextHandle,
472 path: &ProjectPath,
473 project: &Project,
474 cx: &App,
475 ) -> Option<Self> {
476 let worktree = project
477 .worktree_for_entry(directory_context.entry_id, cx)?
478 .read(cx);
479 let entry = worktree.entry_for_id(directory_context.entry_id)?;
480 let directory_path = ProjectPath {
481 worktree_id: worktree.id(),
482 path: entry.path.clone(),
483 };
484 if path.starts_with(&directory_path) {
485 if path == &directory_path {
486 Some(FileInclusion::Direct)
487 } else {
488 Some(FileInclusion::InDirectory {
489 full_path: worktree.full_path(&entry.path),
490 })
491 }
492 } else {
493 None
494 }
495 }
496}