mention.rs

  1use agent::ThreadId;
  2use anyhow::{Context as _, Result, bail};
  3use prompt_store::{PromptId, UserPromptId};
  4use std::{
  5    fmt,
  6    ops::Range,
  7    path::{Path, PathBuf},
  8};
  9use url::Url;
 10
 11#[derive(Clone, Debug, PartialEq, Eq)]
 12pub enum MentionUri {
 13    File(PathBuf),
 14    Symbol {
 15        path: PathBuf,
 16        name: String,
 17        line_range: Range<u32>,
 18    },
 19    Thread {
 20        id: ThreadId,
 21        name: String,
 22    },
 23    TextThread {
 24        path: PathBuf,
 25        name: String,
 26    },
 27    Rule {
 28        id: PromptId,
 29        name: String,
 30    },
 31    Selection {
 32        path: PathBuf,
 33        line_range: Range<u32>,
 34    },
 35    Fetch {
 36        url: Url,
 37    },
 38}
 39
 40impl MentionUri {
 41    pub fn parse(input: &str) -> Result<Self> {
 42        let url = url::Url::parse(input)?;
 43        let path = url.path();
 44        match url.scheme() {
 45            "file" => {
 46                if let Some(fragment) = url.fragment() {
 47                    let range = fragment
 48                        .strip_prefix("L")
 49                        .context("Line range must start with \"L\"")?;
 50                    let (start, end) = range
 51                        .split_once(":")
 52                        .context("Line range must use colon as separator")?;
 53                    let line_range = start
 54                        .parse::<u32>()
 55                        .context("Parsing line range start")?
 56                        .checked_sub(1)
 57                        .context("Line numbers should be 1-based")?
 58                        ..end
 59                            .parse::<u32>()
 60                            .context("Parsing line range end")?
 61                            .checked_sub(1)
 62                            .context("Line numbers should be 1-based")?;
 63                    if let Some(name) = single_query_param(&url, "symbol")? {
 64                        Ok(Self::Symbol {
 65                            name,
 66                            path: path.into(),
 67                            line_range,
 68                        })
 69                    } else {
 70                        Ok(Self::Selection {
 71                            path: path.into(),
 72                            line_range,
 73                        })
 74                    }
 75                } else {
 76                    Ok(Self::File(path.into()))
 77                }
 78            }
 79            "zed" => {
 80                if let Some(thread_id) = path.strip_prefix("/agent/thread/") {
 81                    let name = single_query_param(&url, "name")?.context("Missing thread name")?;
 82                    Ok(Self::Thread {
 83                        id: thread_id.into(),
 84                        name,
 85                    })
 86                } else if let Some(path) = path.strip_prefix("/agent/text-thread/") {
 87                    let name = single_query_param(&url, "name")?.context("Missing thread name")?;
 88                    Ok(Self::TextThread {
 89                        path: path.into(),
 90                        name,
 91                    })
 92                } else if let Some(rule_id) = path.strip_prefix("/agent/rule/") {
 93                    let name = single_query_param(&url, "name")?.context("Missing rule name")?;
 94                    let rule_id = UserPromptId(rule_id.parse()?);
 95                    Ok(Self::Rule {
 96                        id: rule_id.into(),
 97                        name,
 98                    })
 99                } else {
100                    bail!("invalid zed url: {:?}", input);
101                }
102            }
103            "http" | "https" => Ok(MentionUri::Fetch { url }),
104            other => bail!("unrecognized scheme {:?}", other),
105        }
106    }
107
108    fn name(&self) -> String {
109        match self {
110            MentionUri::File(path) => path
111                .file_name()
112                .unwrap_or_default()
113                .to_string_lossy()
114                .into_owned(),
115            MentionUri::Symbol { name, .. } => name.clone(),
116            MentionUri::Thread { name, .. } => name.clone(),
117            MentionUri::TextThread { name, .. } => name.clone(),
118            MentionUri::Rule { name, .. } => name.clone(),
119            MentionUri::Selection {
120                path, line_range, ..
121            } => selection_name(path, line_range),
122            MentionUri::Fetch { url } => url.to_string(),
123        }
124    }
125
126    pub fn as_link<'a>(&'a self) -> MentionLink<'a> {
127        MentionLink(self)
128    }
129
130    pub fn to_uri(&self) -> Url {
131        match self {
132            MentionUri::File(path) => {
133                let mut url = Url::parse("file:///").unwrap();
134                url.set_path(&path.to_string_lossy());
135                url
136            }
137            MentionUri::Symbol {
138                path,
139                name,
140                line_range,
141            } => {
142                let mut url = Url::parse("file:///").unwrap();
143                url.set_path(&path.to_string_lossy());
144                url.query_pairs_mut().append_pair("symbol", name);
145                url.set_fragment(Some(&format!(
146                    "L{}:{}",
147                    line_range.start + 1,
148                    line_range.end + 1
149                )));
150                url
151            }
152            MentionUri::Selection { path, line_range } => {
153                let mut url = Url::parse("file:///").unwrap();
154                url.set_path(&path.to_string_lossy());
155                url.set_fragment(Some(&format!(
156                    "L{}:{}",
157                    line_range.start + 1,
158                    line_range.end + 1
159                )));
160                url
161            }
162            MentionUri::Thread { name, id } => {
163                let mut url = Url::parse("zed:///").unwrap();
164                url.set_path(&format!("/agent/thread/{id}"));
165                url.query_pairs_mut().append_pair("name", name);
166                url
167            }
168            MentionUri::TextThread { path, name } => {
169                let mut url = Url::parse("zed:///").unwrap();
170                url.set_path(&format!("/agent/text-thread/{}", path.to_string_lossy()));
171                url.query_pairs_mut().append_pair("name", name);
172                url
173            }
174            MentionUri::Rule { name, id } => {
175                let mut url = Url::parse("zed:///").unwrap();
176                url.set_path(&format!("/agent/rule/{id}"));
177                url.query_pairs_mut().append_pair("name", name);
178                url
179            }
180            MentionUri::Fetch { url } => url.clone(),
181        }
182    }
183}
184
185pub struct MentionLink<'a>(&'a MentionUri);
186
187impl fmt::Display for MentionLink<'_> {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        write!(f, "[@{}]({})", self.0.name(), self.0.to_uri())
190    }
191}
192
193fn single_query_param(url: &Url, name: &'static str) -> Result<Option<String>> {
194    let pairs = url.query_pairs().collect::<Vec<_>>();
195    match pairs.as_slice() {
196        [] => Ok(None),
197        [(k, v)] => {
198            if k != name {
199                bail!("invalid query parameter")
200            }
201
202            Ok(Some(v.to_string()))
203        }
204        _ => bail!("too many query pairs"),
205    }
206}
207
/// Builds the display label for a line selection, e.g. `file.rs (5:15)`.
///
/// `line_range` is 0-based and both bounds are shown 1-based. A path without
/// a final file-name component renders with an empty name.
pub fn selection_name(path: &Path, line_range: &Range<u32>) -> String {
    let file_name = path
        .file_name()
        .map(|name| name.to_string_lossy())
        .unwrap_or_default();
    let first_line = line_range.start + 1;
    let last_line = line_range.end + 1;
    format!("{file_name} ({first_line}:{last_line})")
}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_file_uri() {
        let file_uri = "file:///path/to/file.rs";
        let parsed = MentionUri::parse(file_uri).unwrap();
        match &parsed {
            MentionUri::File(path) => assert_eq!(path.to_str().unwrap(), "/path/to/file.rs"),
            _ => panic!("Expected File variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), file_uri);
    }

    #[test]
    fn test_parse_symbol_uri() {
        let symbol_uri = "file:///path/to/file.rs?symbol=MySymbol#L10:20";
        let parsed = MentionUri::parse(symbol_uri).unwrap();
        match &parsed {
            MentionUri::Symbol {
                path,
                name,
                line_range,
            } => {
                assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
                assert_eq!(name, "MySymbol");
                assert_eq!(line_range.start, 9);
                assert_eq!(line_range.end, 19);
            }
            _ => panic!("Expected Symbol variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), symbol_uri);
    }

    #[test]
    fn test_parse_selection_uri() {
        let selection_uri = "file:///path/to/file.rs#L5:15";
        let parsed = MentionUri::parse(selection_uri).unwrap();
        match &parsed {
            MentionUri::Selection { path, line_range } => {
                assert_eq!(path.to_str().unwrap(), "/path/to/file.rs");
                assert_eq!(line_range.start, 4);
                assert_eq!(line_range.end, 14);
            }
            _ => panic!("Expected Selection variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), selection_uri);
    }

    #[test]
    fn test_parse_thread_uri() {
        // `to_uri` form-encodes query values, so spaces serialize as `+`;
        // using `%20` here would break the round-trip assertion below.
        let thread_uri = "zed:///agent/thread/session123?name=Thread+name";
        let parsed = MentionUri::parse(thread_uri).unwrap();
        match &parsed {
            MentionUri::Thread {
                id: thread_id,
                name,
            } => {
                assert_eq!(thread_id.to_string(), "session123");
                assert_eq!(name, "Thread name");
            }
            _ => panic!("Expected Thread variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), thread_uri);
    }

    #[test]
    fn test_parse_text_thread_uri() {
        let text_thread_uri = "zed:///agent/text-thread/path/to/thread.json?name=Text+thread";
        let parsed = MentionUri::parse(text_thread_uri).unwrap();
        match &parsed {
            MentionUri::TextThread { path, name } => {
                assert_eq!(path.to_str().unwrap(), "path/to/thread.json");
                assert_eq!(name, "Text thread");
            }
            _ => panic!("Expected TextThread variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), text_thread_uri);
    }

    #[test]
    fn test_parse_rule_uri() {
        // Spaces in query values serialize as `+` (see `test_parse_thread_uri`).
        let rule_uri = "zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52?name=Some+rule";
        let parsed = MentionUri::parse(rule_uri).unwrap();
        match &parsed {
            MentionUri::Rule { id, name } => {
                assert_eq!(id.to_string(), "d8694ff2-90d5-4b6f-be33-33c1763acd52");
                assert_eq!(name, "Some rule");
            }
            _ => panic!("Expected Rule variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), rule_uri);
    }

    #[test]
    fn test_parse_fetch_http_uri() {
        let http_uri = "http://example.com/path?query=value#fragment";
        let parsed = MentionUri::parse(http_uri).unwrap();
        match &parsed {
            MentionUri::Fetch { url } => {
                assert_eq!(url.to_string(), http_uri);
            }
            _ => panic!("Expected Fetch variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), http_uri);
    }

    #[test]
    fn test_parse_fetch_https_uri() {
        let https_uri = "https://example.com/api/endpoint";
        let parsed = MentionUri::parse(https_uri).unwrap();
        match &parsed {
            MentionUri::Fetch { url } => {
                assert_eq!(url.to_string(), https_uri);
            }
            _ => panic!("Expected Fetch variant"),
        }
        assert_eq!(parsed.to_uri().to_string(), https_uri);
    }

    #[test]
    fn test_invalid_scheme() {
        assert!(MentionUri::parse("ftp://example.com").is_err());
        assert!(MentionUri::parse("ssh://example.com").is_err());
        assert!(MentionUri::parse("unknown://example.com").is_err());
    }

    #[test]
    fn test_invalid_zed_path() {
        assert!(MentionUri::parse("zed:///invalid/path").is_err());
        assert!(MentionUri::parse("zed:///agent/unknown/test").is_err());
    }

    #[test]
    fn test_missing_name_parameter() {
        assert!(MentionUri::parse("zed:///agent/thread/session123").is_err());
        assert!(MentionUri::parse("zed:///agent/text-thread/thread.json").is_err());
        assert!(
            MentionUri::parse("zed:///agent/rule/d8694ff2-90d5-4b6f-be33-33c1763acd52").is_err()
        );
    }

    #[test]
    fn test_invalid_line_range_format() {
        // Missing L prefix
        assert!(MentionUri::parse("file:///path/to/file.rs#10:20").is_err());

        // Missing colon separator
        assert!(MentionUri::parse("file:///path/to/file.rs#L1020").is_err());

        // Invalid numbers
        assert!(MentionUri::parse("file:///path/to/file.rs#L10:abc").is_err());
        assert!(MentionUri::parse("file:///path/to/file.rs#Labc:20").is_err());
    }

    #[test]
    fn test_invalid_query_parameters() {
        // The query must precede the fragment; anything after `#` belongs to
        // the fragment and would fail line-range parsing instead of exercising
        // query validation.

        // Invalid query parameter name
        assert!(MentionUri::parse("file:///path/to/file.rs?invalid=test#L10:20").is_err());

        // Too many query parameters
        assert!(
            MentionUri::parse("file:///path/to/file.rs?symbol=test&another=param#L10:20").is_err()
        );
    }

    #[test]
    fn test_zero_based_line_numbers() {
        // Test that 0-based line numbers are rejected (should be 1-based)
        assert!(MentionUri::parse("file:///path/to/file.rs#L0:10").is_err());
        assert!(MentionUri::parse("file:///path/to/file.rs#L1:0").is_err());
        assert!(MentionUri::parse("file:///path/to/file.rs#L0:0").is_err());
    }
}