// project_command.rs

  1use std::{
  2    fmt::Write as _,
  3    ops::DerefMut,
  4    sync::{atomic::AtomicBool, Arc},
  5};
  6
  7use anyhow::{anyhow, Result};
  8use assistant_slash_command::{
  9    ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
 10    SlashCommandResult,
 11};
 12use feature_flags::FeatureFlag;
 13use gpui::{App, Task, WeakEntity};
 14use language::{Anchor, CodeLabel, LspAdapterDelegate};
 15use language_model::{LanguageModelRegistry, LanguageModelTool};
 16use prompt_store::PromptBuilder;
 17use schemars::JsonSchema;
 18use semantic_index::SemanticDb;
 19use serde::Deserialize;
 20use ui::prelude::*;
 21use workspace::Workspace;
 22
 23use super::{create_label_for_command, search_command::add_search_result_section};
 24
/// Marker type for the feature flag that gates the `/project` slash command.
pub struct ProjectSlashCommandFeatureFlag;
 26
impl FeatureFlag for ProjectSlashCommandFeatureFlag {
    // Flag name used to look up whether `/project` is enabled.
    const NAME: &'static str = "project-slash-command";
}
 30
/// Slash command that gathers project context by asking the active language
/// model for semantic search queries and running them against the project's
/// semantic index.
pub struct ProjectSlashCommand {
    // Builds the prompt used to request search queries from the model.
    prompt_builder: Arc<PromptBuilder>,
}
 34
 35impl ProjectSlashCommand {
 36    pub fn new(prompt_builder: Arc<PromptBuilder>) -> Self {
 37        Self { prompt_builder }
 38    }
 39}
 40
impl SlashCommand for ProjectSlashCommand {
    fn name(&self) -> String {
        "project".into()
    }

    fn label(&self, cx: &App) -> CodeLabel {
        // No arguments to show in the label; just the bare command name.
        create_label_for_command("project", &[], cx)
    }

    fn description(&self) -> String {
        "Generate a semantic search based on context".into()
    }

    fn icon(&self) -> IconName {
        IconName::Folder
    }

    fn menu_text(&self) -> String {
        // Menu entry reuses the description text.
        self.description()
    }

    fn requires_argument(&self) -> bool {
        false
    }

    fn complete_argument(
        self: Arc<Self>,
        _arguments: &[String],
        _cancel: Arc<AtomicBool>,
        _workspace: Option<WeakEntity<Workspace>>,
        _window: &mut Window,
        _cx: &mut App,
    ) -> Task<Result<Vec<ArgumentCompletion>>> {
        // The command takes no arguments, so there is nothing to complete.
        Task::ready(Ok(Vec::new()))
    }

    /// Asks the active model for semantic search queries derived from the
    /// context buffer, runs them against the project's semantic index, and
    /// emits the matching excerpts as labelled output sections.
    fn run(
        self: Arc<Self>,
        _arguments: &[String],
        _context_slash_command_output_sections: &[SlashCommandOutputSection<Anchor>],
        context_buffer: language::BufferSnapshot,
        workspace: WeakEntity<Workspace>,
        _delegate: Option<Arc<dyn LspAdapterDelegate>>,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<SlashCommandResult> {
        // Capture everything we need from the UI thread before spawning.
        let model_registry = LanguageModelRegistry::read_global(cx);
        let current_model = model_registry.active_model();
        let prompt_builder = self.prompt_builder.clone();

        let Some(workspace) = workspace.upgrade() else {
            return Task::ready(Err(anyhow::anyhow!("workspace was dropped")));
        };
        let project = workspace.read(cx).project().clone();
        let fs = project.read(cx).fs().clone();
        // Obtain (or fail on) the semantic index for this project.
        let Some(project_index) =
            cx.update_global(|index: &mut SemanticDb, cx| index.project_index(project, cx))
        else {
            return Task::ready(Err(anyhow::anyhow!("no project indexer")));
        };

        window.spawn(cx, async move |cx| {
            // Fail early if no model was selected when the command ran.
            let current_model = current_model.ok_or_else(|| anyhow!("no model selected"))?;

            // Build the prompt from the current context buffer's text.
            let prompt =
                prompt_builder.generate_project_slash_command_prompt(context_buffer.text())?;

            // Ask the model to produce search queries via the SearchQueries tool.
            let search_queries = current_model
                .use_tool::<SearchQueries>(
                    language_model::LanguageModelRequest {
                        messages: vec![language_model::LanguageModelRequestMessage {
                            role: language_model::Role::User,
                            content: vec![language_model::MessageContent::Text(prompt)],
                            cache: false,
                        }],
                        tools: vec![],
                        stop: vec![],
                        temperature: None,
                    },
                    cx.deref_mut(),
                )
                .await?
                .search_queries;

            // Run all queries against the index, capped at 25 results.
            let results = project_index
                .read_with(cx, |project_index, cx| {
                    project_index.search(search_queries.clone(), 25, cx)
                })?
                .await?;

            // Load the matched excerpts from disk.
            let results = SemanticDb::load_results(results, &fs, &cx).await?;

            // Assemble the output text off the UI thread.
            cx.background_spawn(async move {
                let mut output = "Project context:\n".to_string();
                let mut sections = Vec::new();

                // Group results under the query that produced them
                // (results carry the index of their originating query).
                for (ix, query) in search_queries.into_iter().enumerate() {
                    let start_ix = output.len();
                    writeln!(&mut output, "Results for {query}:").unwrap();
                    let mut has_results = false;
                    for result in &results {
                        if result.query_index == ix {
                            add_search_result_section(result, &mut output, &mut sections);
                            has_results = true;
                        }
                    }
                    if has_results {
                        // Wrap this query's results in a collapsible section.
                        sections.push(SlashCommandOutputSection {
                            range: start_ix..output.len(),
                            icon: IconName::MagnifyingGlass,
                            label: query.into(),
                            metadata: None,
                        });
                        output.push('\n');
                    } else {
                        // Drop the header we wrote for a query with no hits.
                        output.truncate(start_ix);
                    }
                }

                // Outermost section spanning the entire output.
                sections.push(SlashCommandOutputSection {
                    range: 0..output.len(),
                    icon: IconName::Book,
                    label: "Project context".into(),
                    metadata: None,
                });

                Ok(SlashCommandOutput {
                    text: output,
                    sections,
                    run_commands_in_text: true,
                }
                .to_event_stream())
            })
            .await
        })
    }
}
178
179#[derive(JsonSchema, Deserialize)]
180struct SearchQueries {
181    /// An array of semantic search queries.
182    ///
183    /// These queries will be used to search the user's codebase.
184    /// The function can only accept 4 queries, otherwise it will error.
185    /// As such, it's important that you limit the length of the search_queries array to 5 queries or less.
186    search_queries: Vec<String>,
187}
188
189impl LanguageModelTool for SearchQueries {
190    fn name() -> String {
191        "search_queries".to_string()
192    }
193
194    fn description() -> String {
195        "Generate semantic search queries based on context".to_string()
196    }
197}