fix(bedrock): prefix model id with region automatically

Andrey Nering created

Change summary

providers/anthropic/anthropic.go |  2 ++
providers/anthropic/bedrock.go   | 13 +++++++++++++
providertests/bedrock_test.go    |  8 ++++----
3 files changed, 19 insertions(+), 4 deletions(-)

Detailed changes

providers/anthropic/anthropic.go 🔗

@@ -157,6 +157,8 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L
 		)
 	}
 	if a.options.useBedrock {
+		modelID = bedrockPrefixModelWithRegion(modelID)
+
 		if a.options.skipAuth || a.options.apiKey != "" {
 			clientOptions = append(
 				clientOptions,

providers/anthropic/bedrock.go 🔗

@@ -3,6 +3,7 @@ package anthropic
 import (
 	"cmp"
 	"os"
+	"strings"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/smithy-go/auth/bearer"
@@ -14,3 +15,15 @@ func bedrockBasicAuthConfig(apiKey string) aws.Config {
 		BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
 	}
 }
+
+func bedrockPrefixModelWithRegion(modelID string) string {
+	region := os.Getenv("AWS_REGION")
+	if len(region) < 2 {
+		region = "us-east-1"
+	}
+	prefix := region[:2] + "."
+	if strings.HasPrefix(modelID, prefix) {
+		return modelID
+	}
+	return prefix + modelID
+}

providertests/bedrock_test.go 🔗

@@ -30,7 +30,7 @@ func builderBedrockClaude3Sonnet(t *testing.T, r *recorder.Recorder) (fantasy.La
 	if err != nil {
 		return nil, err
 	}
-	return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-sonnet-20240229-v1:0")
+	return provider.LanguageModel(t.Context(), "anthropic.claude-3-sonnet-20240229-v1:0")
 }
 
 func builderBedrockClaude3Opus(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
@@ -41,7 +41,7 @@ func builderBedrockClaude3Opus(t *testing.T, r *recorder.Recorder) (fantasy.Lang
 	if err != nil {
 		return nil, err
 	}
-	return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-opus-20240229-v1:0")
+	return provider.LanguageModel(t.Context(), "anthropic.claude-3-opus-20240229-v1:0")
 }
 
 func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
@@ -52,7 +52,7 @@ func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.Lan
 	if err != nil {
 		return nil, err
 	}
-	return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-haiku-20240307-v1:0")
+	return provider.LanguageModel(t.Context(), "anthropic.claude-3-haiku-20240307-v1:0")
 }
 
 func buildersBedrockBasicAuth(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
@@ -64,5 +64,5 @@ func buildersBedrockBasicAuth(t *testing.T, r *recorder.Recorder) (fantasy.Langu
 	if err != nil {
 		return nil, err
 	}
-	return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-sonnet-20240229-v1:0")
+	return provider.LanguageModel(t.Context(), "anthropic.claude-3-sonnet-20240229-v1:0")
 }