refactor: json check func should not be part of the public api

Andrey Nering created

Change summary

internal/jsonext/json.go | 14 ++++++++++++++
providers/openai.go      |  5 +++--
util.go                  |  8 --------
3 files changed, 17 insertions(+), 10 deletions(-)

Detailed changes

internal/jsonext/json.go 🔗

@@ -0,0 +1,14 @@
+package jsonext
+
+import (
+	"encoding/json"
+)
+
+func IsValidJSON[T string | []byte](data T) bool {
+	if len(data) == 0 { // hot path
+		return false
+	}
+	var m json.RawMessage
+	err := json.Unmarshal([]byte(data), &m)
+	return err == nil
+}

providers/openai.go 🔗

@@ -11,6 +11,7 @@ import (
 	"strings"
 
 	"github.com/charmbracelet/ai"
+	"github.com/charmbracelet/ai/internal/jsonext"
 	"github.com/google/uuid"
 	"github.com/openai/openai-go/v2"
 	"github.com/openai/openai-go/v2/option"
@@ -618,7 +619,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
 								return
 							}
 							toolCalls[toolCallDelta.Index] = existingToolCall
-							if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
+							if jsonext.IsValidJSON(existingToolCall.arguments) {
 								if !yield(ai.StreamPart{
 									Type: ai.StreamPartTypeToolInputEnd,
 									ID:   existingToolCall.id,
@@ -679,7 +680,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
 								}) {
 									return
 								}
-								if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
+								if jsonext.IsValidJSON(toolCalls[toolCallDelta.Index].arguments) {
 									if !yield(ai.StreamPart{
 										Type: ai.StreamPartTypeToolInputEnd,
 										ID:   toolCallDelta.ID,

util.go 🔗

@@ -1,8 +1,6 @@
 package ai
 
 import (
-	"encoding/json"
-
 	"github.com/go-viper/mapstructure/v2"
 )
 
@@ -13,9 +11,3 @@ func ParseOptions[T any](options map[string]any, m *T) error {
 func FloatOption(f float64) *float64 {
 	return &f
 }
-
-func IsParsableJSON(data string) bool {
-	var m map[string]any
-	err := json.Unmarshal([]byte(data), &m)
-	return err == nil
-}