tools.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"slices"
  7
  8	"github.com/charmbracelet/crush/internal/csync"
  9)
 10
 11type ToolInfo struct {
 12	Name        string
 13	Description string
 14	Parameters  map[string]any
 15	Required    []string
 16}
 17
 18type toolResponseType string
 19
 20type (
 21	sessionIDContextKey string
 22	messageIDContextKey string
 23)
 24
 25const (
 26	ToolResponseTypeText  toolResponseType = "text"
 27	ToolResponseTypeImage toolResponseType = "image"
 28
 29	SessionIDContextKey sessionIDContextKey = "session_id"
 30	MessageIDContextKey messageIDContextKey = "message_id"
 31)
 32
 33type ToolResponse struct {
 34	Type     toolResponseType `json:"type"`
 35	Content  string           `json:"content"`
 36	Metadata string           `json:"metadata,omitempty"`
 37	IsError  bool             `json:"is_error"`
 38}
 39
 40func NewTextResponse(content string) ToolResponse {
 41	return ToolResponse{
 42		Type:    ToolResponseTypeText,
 43		Content: content,
 44	}
 45}
 46
 47func WithResponseMetadata(response ToolResponse, metadata any) ToolResponse {
 48	if metadata != nil {
 49		metadataBytes, err := json.Marshal(metadata)
 50		if err != nil {
 51			return response
 52		}
 53		response.Metadata = string(metadataBytes)
 54	}
 55	return response
 56}
 57
 58func NewTextErrorResponse(content string) ToolResponse {
 59	return ToolResponse{
 60		Type:    ToolResponseTypeText,
 61		Content: content,
 62		IsError: true,
 63	}
 64}
 65
 66type ToolCall struct {
 67	ID    string `json:"id"`
 68	Name  string `json:"name"`
 69	Input string `json:"input"`
 70}
 71
 72type BaseTool interface {
 73	Info() ToolInfo
 74	Name() string
 75	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
 76}
 77
 78func GetContextValues(ctx context.Context) (string, string) {
 79	sessionID := ctx.Value(SessionIDContextKey)
 80	messageID := ctx.Value(MessageIDContextKey)
 81	if sessionID == nil {
 82		return "", ""
 83	}
 84	if messageID == nil {
 85		return sessionID.(string), ""
 86	}
 87	return sessionID.(string), messageID.(string)
 88}
 89
 90type Registry interface {
 91	GetTool(name string) (BaseTool, bool)
 92	SetTool(name string, tool BaseTool)
 93	GetAllTools() []BaseTool
 94}
 95
 96type registry struct {
 97	tools *csync.LazySlice[BaseTool]
 98}
 99
100func (r *registry) GetAllTools() []BaseTool {
101	return slices.Collect(r.tools.Seq())
102}
103
104func (r *registry) GetTool(name string) (BaseTool, bool) {
105	for tool := range r.tools.Seq() {
106		if tool.Name() == name {
107			return tool, true
108		}
109	}
110
111	return nil, false
112}
113
114func (r *registry) SetTool(name string, tool BaseTool) {
115	for k, tool := range r.tools.Seq2() {
116		if tool.Name() == name {
117			r.tools.Set(k, tool)
118			return
119		}
120	}
121	r.tools.Append(tool)
122}
123
124type LazyToolsFn func() []BaseTool
125
126func NewRegistry(lazyTools LazyToolsFn) Registry {
127	return &registry{
128		tools: csync.NewLazySlice(lazyTools),
129	}
130}
131
132func NewRegistryFromTools(tools []BaseTool) Registry {
133	return &registry{
134		tools: csync.NewLazySlice(func() []BaseTool { return tools }),
135	}
136}