1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "slices"
7
8 "github.com/charmbracelet/crush/internal/csync"
9)
10
// ToolInfo describes a tool: its name, a human-readable description, and the
// parameters it accepts.
type ToolInfo struct {
	Name, Description string
	// Parameters holds the tool's parameter definitions keyed by parameter
	// name (presumably JSON-schema style property objects — confirm with
	// callers).
	Parameters map[string]any
	// Required lists the names of parameters that must be supplied.
	Required []string
}
17
// toolResponseType identifies the kind of payload carried by a ToolResponse.
type toolResponseType string

// Payload kinds a ToolResponse can carry.
const (
	ToolResponseTypeText  toolResponseType = "text"
	ToolResponseTypeImage toolResponseType = "image"
)

// Dedicated unexported key types: context values are matched on both key type
// and key value, so these keys cannot collide with plain string keys set by
// other packages.
type (
	sessionIDContextKey string
	messageIDContextKey string
)

// Context keys read by GetContextValues for the current session and message
// IDs.
const (
	SessionIDContextKey sessionIDContextKey = "session_id"
	MessageIDContextKey messageIDContextKey = "message_id"
)
32
33type ToolResponse struct {
34 Type toolResponseType `json:"type"`
35 Content string `json:"content"`
36 Metadata string `json:"metadata,omitempty"`
37 IsError bool `json:"is_error"`
38}
39
40func NewTextResponse(content string) ToolResponse {
41 return ToolResponse{
42 Type: ToolResponseTypeText,
43 Content: content,
44 }
45}
46
// WithResponseMetadata returns a copy of response whose Metadata field holds
// the JSON encoding of metadata.
//
// Attaching metadata is best-effort: a nil metadata, or a value json.Marshal
// cannot encode, yields response unchanged rather than an error.
func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
	if metadata == nil {
		return response
	}
	encoded, err := json.Marshal(metadata)
	if err == nil {
		response.Metadata = string(encoded)
	}
	return response
}
57
58func NewTextErrorResponse(content string) ToolResponse {
59 return ToolResponse{
60 Type: ToolResponseTypeText,
61 Content: content,
62 IsError: true,
63 }
64}
65
// ToolCall is a single request to run a tool.
type ToolCall struct {
	// ID identifies this particular call.
	ID string `json:"id"`
	// Name selects the tool to run.
	Name string `json:"name"`
	// Input carries the call's arguments as a raw string (presumably JSON —
	// confirm with tool implementations).
	Input string `json:"input"`
}
71
72type BaseTool interface {
73 Info() ToolInfo
74 Name() string
75 Run(ctx context.Context, params ToolCall) (ToolResponse, error)
76}
77
// GetContextValues returns the session ID and message ID stored in ctx under
// SessionIDContextKey and MessageIDContextKey.
//
// The message ID is only reported alongside a session ID: when the session ID
// is missing, both results are empty. A missing message ID yields "" for that
// result alone. A value stored under either key that is not a string is
// treated as missing rather than causing a panic.
func GetContextValues(ctx context.Context) (string, string) {
	// Comma-ok assertions cover both the unset (nil) case and values of an
	// unexpected type; a bare assertion would panic on the latter.
	sessionID, ok := ctx.Value(SessionIDContextKey).(string)
	if !ok {
		return "", ""
	}
	messageID, _ := ctx.Value(MessageIDContextKey).(string)
	return sessionID, messageID
}
89
90type Registry interface {
91 GetTool(name string) (BaseTool, bool)
92 SetTool(name string, tool BaseTool)
93 GetAllTools() []BaseTool
94}
95
96type registry struct {
97 tools *csync.LazySlice[BaseTool]
98}
99
// GetAllTools returns the registered tools in the order the underlying slice
// yields them, collected into a fresh slice (nil when there are none).
func (r *registry) GetAllTools() []BaseTool {
	all := r.tools.Seq()
	return slices.Collect(all)
}
103
104func (r *registry) GetTool(name string) (BaseTool, bool) {
105 for tool := range r.tools.Seq() {
106 if tool.Name() == name {
107 return tool, true
108 }
109 }
110
111 return nil, false
112}
113
114func (r *registry) SetTool(name string, tool BaseTool) {
115 for k, tool := range r.tools.Seq2() {
116 if tool.Name() == name {
117 r.tools.Set(k, tool)
118 return
119 }
120 }
121 r.tools.Append(tool)
122}
123
124type LazyToolsFn func() []BaseTool
125
126func NewRegistry(lazyTools LazyToolsFn) Registry {
127 return ®istry{
128 tools: csync.NewLazySlice(lazyTools),
129 }
130}
131
132func NewRegistryFromTools(tools []BaseTool) Registry {
133 return ®istry{
134 tools: csync.NewLazySlice(func() []BaseTool { return tools }),
135 }
136}