fix: adjustments after rebase

Andrey Nering created

* Implement interface
* Fix options cast
* Make option structs public

Change summary

google/google.go           | 23 ++++++++++++++++++-----
google/provider_options.go | 14 +++++++++-----
2 files changed, 27 insertions(+), 10 deletions(-)

Detailed changes

google/google.go 🔗

@@ -69,6 +69,18 @@ func WithHTTPClient(client *http.Client) Option {
 	}
 }
 
+func (*provider) Name() string {
+	return Name
+}
+
+func (a *provider) ParseOptions(data map[string]any) (ai.ProviderOptionsData, error) {
+	var options ProviderOptions
+	if err := ai.ParseOptions(data, &options); err != nil {
+		return nil, err
+	}
+	return &options, nil
+}
+
 type languageModel struct {
 	provider        string
 	modelID         string
@@ -97,11 +109,12 @@ func (g *provider) LanguageModel(modelID string) (ai.LanguageModel, error) {
 
 func (a languageModel) prepareParams(call ai.Call) (*genai.GenerateContentConfig, []*genai.Content, []ai.CallWarning, error) {
 	config := &genai.GenerateContentConfig{}
-	providerOptions := &providerOptions{}
-	if v, ok := call.ProviderOptions["google"]; ok {
-		err := ai.ParseOptions(v, providerOptions)
-		if err != nil {
-			return nil, nil, nil, err
+
+	providerOptions := &ProviderOptions{}
+	if v, ok := call.ProviderOptions[Name]; ok {
+		providerOptions, ok = v.(*ProviderOptions)
+		if !ok {
+			return nil, nil, nil, ai.NewInvalidArgumentError("providerOptions", "anthropic provider options should be *anthropic.ProviderOptions", nil)
 		}
 	}
 

google/provider_options.go 🔗

@@ -1,11 +1,13 @@
 package google
 
-type thinkingConfig struct {
+const Name = "google"
+
+type ThinkingConfig struct {
 	ThinkingBudget  *int64 `json:"thinking_budget"`
 	IncludeThoughts *bool  `json:"include_thoughts"`
 }
 
-type safetySetting struct {
+type SafetySetting struct {
 	// 'HARM_CATEGORY_UNSPECIFIED',
 	// 'HARM_CATEGORY_HATE_SPEECH',
 	// 'HARM_CATEGORY_DANGEROUS_CONTENT',
@@ -22,8 +24,8 @@ type safetySetting struct {
 	// 'OFF',
 	Threshold string `json:"threshold"`
 }
-type providerOptions struct {
-	ThinkingConfig *thinkingConfig `json:"thinking_config"`
+type ProviderOptions struct {
+	ThinkingConfig *ThinkingConfig `json:"thinking_config"`
 
 	// Optional.
 	// The name of the cached content used as context to serve the prediction.
@@ -31,7 +33,7 @@ type providerOptions struct {
 	CachedContent string `json:"cached_content"`
 
 	// Optional. A list of unique safety settings for blocking unsafe content.
-	SafetySettings []safetySetting `json:"safety_settings"`
+	SafetySettings []SafetySetting `json:"safety_settings"`
 	// 'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
 	// 'BLOCK_LOW_AND_ABOVE',
 	// 'BLOCK_MEDIUM_AND_ABOVE',
@@ -40,3 +42,5 @@ type providerOptions struct {
 	// 'OFF',
 	Threshold string `json:"threshold"`
 }
+
+func (o *ProviderOptions) Options() {}