chat_with_functions.rs

//! This example creates a basic chat UI with tools for rolling dice and browsing the filesystem.
  2
  3use anyhow::{Context as _, Result};
  4use assets::Assets;
  5use assistant2::AssistantPanel;
  6use assistant_tooling::{LanguageModelTool, ToolRegistry};
  7use client::{Client, UserStore};
  8use fs::Fs;
  9use futures::StreamExt as _;
 10use gpui::{actions, AnyElement, App, AppContext, KeyBinding, Model, Task, View, WindowOptions};
 11use language::LanguageRegistry;
 12use project::Project;
 13use rand::Rng;
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use settings::{KeymapFile, DEFAULT_KEYMAP_PATH};
 17use std::{path::PathBuf, sync::Arc};
 18use theme::LoadThemes;
 19use ui::{div, prelude::*, Render};
 20use util::ResultExt as _;
 21
// Declares the `Quit` action in the `example` namespace; `main` binds it to cmd-q and
// quits the app when it is dispatched.
actions!(example, [Quit]);
 23
 24struct RollDiceTool {}
 25
 26impl RollDiceTool {
 27    fn new() -> Self {
 28        Self {}
 29    }
 30}
 31
 32#[derive(Serialize, Deserialize, JsonSchema, Clone)]
 33#[serde(rename_all = "snake_case")]
 34enum Die {
 35    D6 = 6,
 36    D20 = 20,
 37}
 38
 39impl Die {
 40    fn into_str(&self) -> &'static str {
 41        match self {
 42            Die::D6 => "d6",
 43            Die::D20 => "d20",
 44        }
 45    }
 46}
 47
// Input accepted by the `roll_dice` tool.
// The `///` field comments below are intentionally left verbatim: the `JsonSchema` derive
// (schemars) uses doc comments as field descriptions in the schema the model sees, so
// editing them changes runtime behavior.
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
struct DiceParams {
    /// The number of dice to roll.
    num_dice: u8,
    /// Which die to roll. Defaults to a d6 if not provided.
    die_type: Option<Die>,
}
 55
/// A single rolled die: which die was thrown and the value that came up.
#[derive(Serialize, Deserialize)]
struct DieRoll {
    /// The die that was rolled.
    die: Die,
    /// The rolled value; `RollDiceTool::execute` draws it from `1..=sides`.
    roll: u8,
}
 61
 62impl DieRoll {
 63    fn render(&self) -> AnyElement {
 64        match self.die {
 65            Die::D6 => {
 66                let face = match self.roll {
 67                    6 => div().child(""),
 68                    5 => div().child(""),
 69                    4 => div().child(""),
 70                    3 => div().child(""),
 71                    2 => div().child(""),
 72                    1 => div().child(""),
 73                    _ => div().child("😅"),
 74                };
 75                face.text_3xl().into_any_element()
 76            }
 77            _ => div()
 78                .child(format!("{}", self.roll))
 79                .text_3xl()
 80                .into_any_element(),
 81        }
 82    }
 83}
 84
/// Output of the `roll_dice` tool: one entry per die, in the order they were rolled.
#[derive(Serialize, Deserialize)]
struct DiceRoll {
    rolls: Vec<DieRoll>,
}
 89
/// View that renders the outcome of a `roll_dice` tool call.
pub struct DiceView {
    // The tool's output, or the error the call failed with.
    result: Result<DiceRoll>,
}
 93
 94impl Render for DiceView {
 95    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
 96        let output = match &self.result {
 97            Ok(output) => output,
 98            Err(_) => return "Somehow dice failed 🎲".into_any_element(),
 99        };
100
101        h_flex()
102            .children(
103                output
104                    .rolls
105                    .iter()
106                    .map(|roll| div().p_2().child(roll.render())),
107            )
108            .into_any_element()
109    }
110}
111
112impl LanguageModelTool for RollDiceTool {
113    type Input = DiceParams;
114    type Output = DiceRoll;
115    type View = DiceView;
116
117    fn name(&self) -> String {
118        "roll_dice".to_string()
119    }
120
121    fn description(&self) -> String {
122        "Rolls N many dice and returns the results.".to_string()
123    }
124
125    fn execute(
126        &self,
127        input: &Self::Input,
128        _cx: &mut WindowContext,
129    ) -> Task<gpui::Result<Self::Output>> {
130        let rolls = (0..input.num_dice)
131            .map(|_| {
132                let die_type = input.die_type.as_ref().unwrap_or(&Die::D6).clone();
133
134                DieRoll {
135                    die: die_type.clone(),
136                    roll: rand::thread_rng().gen_range(1..=die_type as u8),
137                }
138            })
139            .collect();
140
141        return Task::ready(Ok(DiceRoll { rolls }));
142    }
143
144    fn output_view(
145        _tool_call_id: String,
146        _input: Self::Input,
147        result: Result<Self::Output>,
148        cx: &mut WindowContext,
149    ) -> gpui::View<Self::View> {
150        cx.new_view(|_cx| DiceView { result })
151    }
152
153    fn format(_: &Self::Input, output: &Result<Self::Output>) -> String {
154        let output = match output {
155            Ok(output) => output,
156            Err(_) => return "Somehow dice failed 🎲".to_string(),
157        };
158
159        let mut result = String::new();
160        for roll in &output.rolls {
161            let die = &roll.die;
162            result.push_str(&format!("{}: {}\n", die.into_str(), roll.roll));
163        }
164        result
165    }
166}
167
168struct FileBrowserTool {
169    fs: Arc<dyn Fs>,
170    root_dir: PathBuf,
171}
172
173impl FileBrowserTool {
174    fn new(fs: Arc<dyn Fs>, root_dir: PathBuf) -> Self {
175        Self { fs, root_dir }
176    }
177}
178
// Input accepted by the `file_browser` tool: a single command to run.
// Plain `//` comments are used deliberately: the `JsonSchema` derive would copy `///` doc
// comments into the schema sent to the model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
struct FileBrowserParams {
    command: FileBrowserCommand,
}
183
// An operation the model can request. No serde attributes are set, so serde's default
// (externally tagged) representation applies, e.g. `{"Ls": {"path": "src"}}`.
// Plain `//` comments are used deliberately: the `JsonSchema` derive would copy `///` doc
// comments into the schema sent to the model.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
enum FileBrowserCommand {
    // List the directory at `path` (joined onto the tool's `root_dir` by `execute`).
    Ls { path: PathBuf },
    // Read the file at `path` (joined onto the tool's `root_dir` by `execute`) as text.
    Cat { path: PathBuf },
}
189
/// Result of a `file_browser` command; one variant per `FileBrowserCommand` variant.
#[derive(Serialize, Deserialize)]
enum FileBrowserOutput {
    /// Display strings of the paths yielded by `Fs::read_dir`, in the order yielded.
    Ls { entries: Vec<String> },
    /// Text of the file as returned by `Fs::load`.
    Cat { content: String },
}
195
/// View that renders the outcome of a `file_browser` tool call.
pub struct FileBrowserView {
    // The command's output, or the error it failed with.
    result: Result<FileBrowserOutput>,
}
199
200impl Render for FileBrowserView {
201    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
202        let Ok(output) = self.result.as_ref() else {
203            return h_flex().child("Failed to perform operation");
204        };
205
206        match output {
207            FileBrowserOutput::Ls { entries } => v_flex().children(
208                entries
209                    .into_iter()
210                    .map(|entry| h_flex().text_ui(cx).child(entry.clone())),
211            ),
212            FileBrowserOutput::Cat { content } => h_flex().child(content.clone()),
213        }
214    }
215}
216
217impl LanguageModelTool for FileBrowserTool {
218    type Input = FileBrowserParams;
219    type Output = FileBrowserOutput;
220    type View = FileBrowserView;
221
222    fn name(&self) -> String {
223        "file_browser".to_string()
224    }
225
226    fn description(&self) -> String {
227        "A tool for browsing the filesystem.".to_string()
228    }
229
230    fn execute(
231        &self,
232        input: &Self::Input,
233        cx: &mut WindowContext,
234    ) -> Task<gpui::Result<Self::Output>> {
235        cx.spawn({
236            let fs = self.fs.clone();
237            let root_dir = self.root_dir.clone();
238            let input = input.clone();
239            |_cx| async move {
240                match input.command {
241                    FileBrowserCommand::Ls { path } => {
242                        let path = root_dir.join(path);
243
244                        let mut output = fs.read_dir(&path).await?;
245
246                        let mut entries = Vec::new();
247                        while let Some(entry) = output.next().await {
248                            let entry = entry?;
249                            entries.push(entry.display().to_string());
250                        }
251
252                        Ok(FileBrowserOutput::Ls { entries })
253                    }
254                    FileBrowserCommand::Cat { path } => {
255                        let path = root_dir.join(path);
256
257                        let output = fs.load(&path).await?;
258
259                        Ok(FileBrowserOutput::Cat { content: output })
260                    }
261                }
262            }
263        })
264    }
265
266    fn output_view(
267        _tool_call_id: String,
268        _input: Self::Input,
269        result: Result<Self::Output>,
270        cx: &mut WindowContext,
271    ) -> gpui::View<Self::View> {
272        cx.new_view(|_cx| FileBrowserView { result })
273    }
274
275    fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
276        let Ok(output) = output else {
277            return "Failed to perform command: {input:?}".to_string();
278        };
279
280        match output {
281            FileBrowserOutput::Ls { entries } => entries.join("\n"),
282            FileBrowserOutput::Cat { content } => content.to_owned(),
283        }
284    }
285}
286
287fn main() {
288    env_logger::init();
289    App::new().with_assets(Assets).run(|cx| {
290        cx.bind_keys(Some(KeyBinding::new("cmd-q", Quit, None)));
291        cx.on_action(|_: &Quit, cx: &mut AppContext| {
292            cx.quit();
293        });
294
295        settings::init(cx);
296        language::init(cx);
297        Project::init_settings(cx);
298        editor::init(cx);
299        theme::init(LoadThemes::JustBase, cx);
300        Assets.load_fonts(cx).unwrap();
301        KeymapFile::load_asset(DEFAULT_KEYMAP_PATH, cx).unwrap();
302        client::init_settings(cx);
303        release_channel::init("0.130.0", cx);
304
305        let client = Client::production(cx);
306        {
307            let client = client.clone();
308            cx.spawn(|cx| async move { client.authenticate_and_connect(false, &cx).await })
309                .detach_and_log_err(cx);
310        }
311        assistant2::init(client.clone(), cx);
312
313        let language_registry = Arc::new(LanguageRegistry::new(
314            Task::ready(()),
315            cx.background_executor().clone(),
316        ));
317
318        let user_store = cx.new_model(|cx| UserStore::new(client.clone(), cx));
319        let node_runtime = node_runtime::RealNodeRuntime::new(client.http_client());
320        languages::init(language_registry.clone(), node_runtime, cx);
321
322        cx.spawn(|cx| async move {
323            cx.update(|cx| {
324                let fs = Arc::new(fs::RealFs::new(None));
325                let cwd = std::env::current_dir().expect("Failed to get current working directory");
326
327                cx.open_window(WindowOptions::default(), |cx| {
328                    let mut tool_registry = ToolRegistry::new();
329                    tool_registry
330                        .register(RollDiceTool::new(), cx)
331                        .context("failed to register DummyTool")
332                        .log_err();
333
334                    tool_registry
335                        .register(FileBrowserTool::new(fs, cwd), cx)
336                        .context("failed to register FileBrowserTool")
337                        .log_err();
338
339                    let tool_registry = Arc::new(tool_registry);
340
341                    println!("Tools registered");
342                    for definition in tool_registry.definitions() {
343                        println!("{}", definition);
344                    }
345
346                    cx.new_view(|cx| Example::new(language_registry, tool_registry, user_store, cx))
347                });
348                cx.activate(true);
349            })
350        })
351        .detach_and_log_err(cx);
352    })
353}
354
/// Root view of the example window; its only content is the assistant panel.
struct Example {
    assistant_panel: View<AssistantPanel>,
}
358
359impl Example {
360    fn new(
361        language_registry: Arc<LanguageRegistry>,
362        tool_registry: Arc<ToolRegistry>,
363        user_store: Model<UserStore>,
364        cx: &mut ViewContext<Self>,
365    ) -> Self {
366        Self {
367            assistant_panel: cx.new_view(|cx| {
368                AssistantPanel::new(language_registry, tool_registry, user_store, None, cx)
369            }),
370        }
371    }
372}
373
374impl Render for Example {
375    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl ui::prelude::IntoElement {
376        div().size_full().child(self.assistant_panel.clone())
377    }
378}