diff --git a/internal/jsonext/json.go b/internal/jsonext/json.go new file mode 100644 index 0000000000000000000000000000000000000000..467f22e2cb478acef6ace9bb3246b956168523dd --- /dev/null +++ b/internal/jsonext/json.go @@ -0,0 +1,14 @@ +package jsonext + +import ( + "encoding/json" +) + +func IsValidJSON[T string | []byte](data T) bool { + if len(data) == 0 { // hot path + return false + } + var m json.RawMessage + err := json.Unmarshal([]byte(data), &m) + return err == nil +} diff --git a/providers/openai.go b/providers/openai.go index c9d396e4343cac1fcf263aee22231d7de91f56ae..e3d5594c964da67a22511d02a3b364da5bcfb81f 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/charmbracelet/ai" + "github.com/charmbracelet/ai/internal/jsonext" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -618,7 +619,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea return } toolCalls[toolCallDelta.Index] = existingToolCall - if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) { + if jsonext.IsValidJSON(existingToolCall.arguments) { if !yield(ai.StreamPart{ Type: ai.StreamPartTypeToolInputEnd, ID: existingToolCall.id, @@ -679,7 +680,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea }) { return } - if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) { + if jsonext.IsValidJSON(toolCalls[toolCallDelta.Index].arguments) { if !yield(ai.StreamPart{ Type: ai.StreamPartTypeToolInputEnd, ID: toolCallDelta.ID, diff --git a/util.go b/util.go index 6f0012d66d132a5810a58c7f8f8bede59cb41956..dc50b4c6c27a59fa4a261f4132cbb8d0d3b4c1b3 100644 --- a/util.go +++ b/util.go @@ -1,8 +1,6 @@ package ai import ( - "encoding/json" - "github.com/go-viper/mapstructure/v2" ) @@ -13,9 +11,3 @@ func ParseOptions[T any](options map[string]any, m *T) error { func FloatOption(f float64) *float64 { return &f } - -func IsParsableJSON(data string) bool { - var m map[string]any - err := json.Unmarshal([]byte(data), &m) - return err == nil -}