zeta_context.rs

  1use anyhow::{Result, anyhow};
  2use clap::{Parser, Subcommand};
  3use ordered_float::OrderedFloat;
  4use serde_json::json;
  5use std::fmt::Display;
  6use std::io::Write;
  7use std::path::Path;
  8use std::str::FromStr;
  9use std::{path::PathBuf, sync::Arc};
 10
 11#[derive(Parser, Debug)]
 12#[command(name = "zeta_context")]
 13struct Args {
 14    #[command(subcommand)]
 15    command: Command,
 16    #[arg(long, default_value_t = FileOrStdio::Stdio)]
 17    log: FileOrStdio,
 18}
 19
 20#[derive(Subcommand, Debug)]
 21enum Command {
 22    ShowIndex {
 23        directory: PathBuf,
 24    },
 25    NearbyReferences {
 26        cursor_position: SourceLocation,
 27        #[arg(long, default_value_t = 10)]
 28        context_lines: u32,
 29    },
 30
 31    Run {
 32        directory: PathBuf,
 33        cursor_position: CursorPosition,
 34        #[arg(long, default_value_t = 2048)]
 35        prompt_limit: usize,
 36        #[arg(long)]
 37        output_scores: Option<FileOrStdio>,
 38        #[command(flatten)]
 39        excerpt_options: ExcerptOptions,
 40    },
 41}
 42
 43#[derive(Clone, Debug)]
 44enum CursorPosition {
 45    Random,
 46    Specific(SourceLocation),
 47}
 48
 49impl CursorPosition {
 50    fn to_source_location_within(
 51        &self,
 52        languages: &[Arc<Language>],
 53        directory: &Path,
 54    ) -> SourceLocation {
 55        match self {
 56            CursorPosition::Random => {
 57                let entries = ignore::Walk::new(directory)
 58                    .filter_map(|result| result.ok())
 59                    .filter(|entry| language_for_file(languages, entry.path()).is_some())
 60                    .collect::<Vec<_>>();
 61                let selected_entry_ix = rand::random_range(0..entries.len());
 62                let path = entries[selected_entry_ix].path().to_path_buf();
 63                let source = std::fs::read_to_string(&path).unwrap();
 64                let offset = rand::random_range(0..source.len());
 65                let point = point_from_offset(&source, offset);
 66                let source_location = SourceLocation { path, point };
 67                log::info!("Selected random cursor position: {source_location}");
 68                source_location
 69            }
 70            CursorPosition::Specific(location) => location.clone(),
 71        }
 72    }
 73}
 74
 75impl Display for CursorPosition {
 76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 77        match self {
 78            CursorPosition::Random => write!(f, "random"),
 79            CursorPosition::Specific(location) => write!(f, "{}", &location),
 80        }
 81    }
 82}
 83
 84impl FromStr for CursorPosition {
 85    type Err = anyhow::Error;
 86
 87    fn from_str(s: &str) -> Result<Self, Self::Err> {
 88        match s {
 89            "random" => Ok(CursorPosition::Random),
 90            _ => Ok(CursorPosition::Specific(SourceLocation::from_str(s)?)),
 91        }
 92    }
 93}
 94
 95#[derive(Debug, Clone)]
 96enum FileOrStdio {
 97    File(PathBuf),
 98    Stdio,
 99}
100
101impl FileOrStdio {
102    #[allow(dead_code)]
103    fn read_to_string(&self) -> Result<String, std::io::Error> {
104        match self {
105            FileOrStdio::File(path) => std::fs::read_to_string(path),
106            FileOrStdio::Stdio => std::io::read_to_string(std::io::stdin()),
107        }
108    }
109
110    fn write_file_or_stdout(&self) -> Result<Box<dyn Write + Send + 'static>, std::io::Error> {
111        match self {
112            FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)),
113            FileOrStdio::Stdio => Ok(Box::new(std::io::stdout())),
114        }
115    }
116
117    fn write_file_or_stderr(
118        &self,
119    ) -> Result<Box<dyn std::io::Write + Send + 'static>, std::io::Error> {
120        match self {
121            FileOrStdio::File(path) => Ok(Box::new(std::fs::File::create(path)?)),
122            FileOrStdio::Stdio => Ok(Box::new(std::io::stderr())),
123        }
124    }
125}
126
127impl Display for FileOrStdio {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        match self {
130            FileOrStdio::File(path) => write!(f, "{}", path.display()),
131            FileOrStdio::Stdio => write!(f, "-"),
132        }
133    }
134}
135
136impl FromStr for FileOrStdio {
137    type Err = <PathBuf as FromStr>::Err;
138
139    fn from_str(s: &str) -> Result<Self, Self::Err> {
140        match s {
141            "-" => Ok(Self::Stdio),
142            _ => Ok(Self::File(PathBuf::from_str(s)?)),
143        }
144    }
145}
146
147fn main() -> Result<()> {
148    let args = ZetaContextArgs::parse();
149    env_logger::Builder::from_default_env()
150        .target(env_logger::Target::Pipe(args.log.write_file_or_stderr()?))
151        .init();
152    let languages = load_languages();
153    match &args.command {
154        Command::ShowIndex { directory } => {
155            /*
156            let directory = directory.canonicalize()?;
157            let index = IdentifierIndex::index_path(&languages, &directory)?;
158            for ((identifier, language_name), files) in &index.identifier_to_definitions {
159                println!("\n{} ({})", identifier.0, language_name.0);
160                for (file, definitions) in files {
161                    println!("  {:?}", file);
162                    for definition in definitions {
163                        println!("    {}", definition.path_string(&index));
164                    }
165                }
166            }
167            */
168            Ok(())
169        }
170
171        Command::NearbyReferences {
172            cursor_position,
173            context_lines,
174        } => {
175            /*
176            let (language, source, tree) = parse_file(&languages, &cursor_position.path)?;
177            let start_offset = offset_from_point(
178                &source,
179                Point::new(cursor_position.point.row.saturating_sub(*context_lines), 0),
180            );
181            let end_offset = offset_from_point(
182                &source,
183                Point::new(cursor_position.point.row + context_lines, 0),
184            );
185            let references = local_identifiers(
186                ReferenceRegion::Nearby,
187                &language,
188                &tree,
189                &source,
190                start_offset..end_offset,
191            );
192            for reference in references {
193                println!(
194                    "{:?} {}",
195                    point_range_from_offset_range(&source, reference.range),
196                    reference.identifier.0,
197                );
198            }
199            */
200            Ok(())
201        }
202
203        Command::Run {
204            directory,
205            cursor_position,
206            prompt_limit,
207            output_scores,
208            excerpt_options,
209        } => {
210            let directory = directory.canonicalize()?;
211            let index = IdentifierIndex::index_path(&languages, &directory)?;
212            let cursor_position = cursor_position.to_source_location_within(&languages, &directory);
213            let excerpt_file: Arc<Path> = cursor_position.path.as_path().into();
214            let (language, source, tree) = parse_file(&languages, &excerpt_file)?;
215            let cursor_offset = offset_from_point(&source, cursor_position.point);
216            let Some(excerpt_ranges) = ExcerptRangesInput {
217                language: &language,
218                tree: &tree,
219                source: &source,
220                cursor_offset,
221                options: excerpt_options,
222            }
223            .select() else {
224                return Err(anyhow!("line containing cursor does not fit within window"));
225            };
226            let mut snippets = gather_snippets(
227                &language,
228                &index,
229                &tree,
230                &excerpt_file,
231                &source,
232                excerpt_ranges.clone(),
233                cursor_offset,
234            );
235            let planned_prompt = PromptPlanner::populate(
236                &index,
237                snippets.clone(),
238                excerpt_file,
239                excerpt_ranges.clone(),
240                cursor_offset,
241                *prompt_limit,
242                &directory,
243            );
244            let prompt_string = planned_prompt.to_prompt_string(&index);
245            println!("{}", &prompt_string);
246
247            if let Some(output_scores) = output_scores {
248                snippets.sort_by_key(|snippet| OrderedFloat(-snippet.scores.signature));
249                let writer = output_scores.write_file_or_stdout()?;
250                serde_json::to_writer_pretty(
251                    writer,
252                    &snippets
253                        .into_iter()
254                        .map(|snippet| {
255                            json!({
256                                "file": snippet.definition_file,
257                                "symbol_path": snippet.definition.path_string(&index),
258                                "signature_score": snippet.scores.signature,
259                                "definition_score": snippet.scores.definition,
260                                "signature_score_density": snippet.score_density(&index, SnippetStyle::Signature),
261                                "definition_score_density": snippet.score_density(&index, SnippetStyle::Definition),
262                                "score_components": snippet.score_components
263                            })
264                        })
265                        .collect::<Vec<_>>(),
266                )?;
267            }
268
269            let actual_window_size = range_size(excerpt_ranges.excerpt_range);
270            if actual_window_size > excerpt_options.window_max_bytes {
271                let exceeded_amount = actual_window_size - excerpt_options.window_max_bytes;
272                if exceeded_amount as f64 / excerpt_options.window_max_bytes as f64 > 0.05 {
273                    log::error!("Exceeded max main excerpt size by {exceeded_amount} bytes");
274                }
275            }
276
277            if prompt_string.len() > *prompt_limit {
278                let exceeded_amount = prompt_string.len() - *prompt_limit;
279                if exceeded_amount as f64 / *prompt_limit as f64 > 0.1 {
280                    log::error!(
281                        "Exceeded max prompt size of {prompt_limit} bytes by {exceeded_amount} bytes"
282                    );
283                }
284            }
285
286            Ok(())
287        }
288    }
289}