Change summary
internal/jsonext/json.go | 14 ++++++++++++++
providers/openai.go | 5 +++--
util.go | 8 --------
3 files changed, 17 insertions(+), 10 deletions(-)
Detailed changes
@@ -0,0 +1,14 @@
+package jsonext
+
+import (
+ "encoding/json"
+)
+
+func IsValidJSON[T string | []byte](data T) bool {
+ if len(data) == 0 { // hot path
+ return false
+ }
+ var m json.RawMessage
+ err := json.Unmarshal([]byte(data), &m)
+ return err == nil
+}
@@ -11,6 +11,7 @@ import (
"strings"
"github.com/charmbracelet/ai"
+ "github.com/charmbracelet/ai/internal/jsonext"
"github.com/google/uuid"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
@@ -618,7 +619,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
return
}
toolCalls[toolCallDelta.Index] = existingToolCall
- if existingToolCall.arguments != "" && ai.IsParsableJSON(existingToolCall.arguments) {
+ if jsonext.IsValidJSON(existingToolCall.arguments) {
if !yield(ai.StreamPart{
Type: ai.StreamPartTypeToolInputEnd,
ID: existingToolCall.id,
@@ -679,7 +680,7 @@ func (o openAiLanguageModel) Stream(ctx context.Context, call ai.Call) (ai.Strea
}) {
return
}
- if ai.IsParsableJSON(toolCalls[toolCallDelta.Index].arguments) {
+ if jsonext.IsValidJSON(toolCalls[toolCallDelta.Index].arguments) {
if !yield(ai.StreamPart{
Type: ai.StreamPartTypeToolInputEnd,
ID: toolCallDelta.ID,
@@ -1,8 +1,6 @@
package ai
import (
- "encoding/json"
-
"github.com/go-viper/mapstructure/v2"
)
@@ -13,9 +11,3 @@ func ParseOptions[T any](options map[string]any, m *T) error {
func FloatOption(f float64) *float64 {
return &f
}
-
-func IsParsableJSON(data string) bool {
- var m map[string]any
- err := json.Unmarshal([]byte(data), &m)
- return err == nil
-}