@@ -157,6 +157,8 @@ func (a *provider) LanguageModel(ctx context.Context, modelID string) (fantasy.L
)
}
if a.options.useBedrock {
+ modelID = bedrockPrefixModelWithRegion(modelID)
+
if a.options.skipAuth || a.options.apiKey != "" {
clientOptions = append(
clientOptions,
@@ -3,6 +3,7 @@ package anthropic
import (
"cmp"
"os"
+ "strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go/auth/bearer"
@@ -14,3 +15,15 @@ func bedrockBasicAuthConfig(apiKey string) aws.Config {
BearerAuthTokenProvider: bearer.StaticTokenProvider{Token: bearer.Token{Value: apiKey}},
}
}
+
+func bedrockPrefixModelWithRegion(modelID string) string {
+ region := os.Getenv("AWS_REGION")
+ if len(region) < 2 {
+ region = "us-east-1"
+ }
+ prefix := region[:2] + "."
+ if strings.HasPrefix(modelID, prefix) {
+ return modelID
+ }
+ return prefix + modelID
+}
@@ -30,7 +30,7 @@ func builderBedrockClaude3Sonnet(t *testing.T, r *recorder.Recorder) (fantasy.La
if err != nil {
return nil, err
}
- return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-sonnet-20240229-v1:0")
+ return provider.LanguageModel(t.Context(), "anthropic.claude-3-sonnet-20240229-v1:0")
}
func builderBedrockClaude3Opus(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
@@ -41,7 +41,7 @@ func builderBedrockClaude3Opus(t *testing.T, r *recorder.Recorder) (fantasy.Lang
if err != nil {
return nil, err
}
- return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-opus-20240229-v1:0")
+ return provider.LanguageModel(t.Context(), "anthropic.claude-3-opus-20240229-v1:0")
}
func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
@@ -52,7 +52,7 @@ func builderBedrockClaude3Haiku(t *testing.T, r *recorder.Recorder) (fantasy.Lan
if err != nil {
return nil, err
}
- return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-haiku-20240307-v1:0")
+ return provider.LanguageModel(t.Context(), "anthropic.claude-3-haiku-20240307-v1:0")
}
func buildersBedrockBasicAuth(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
@@ -64,5 +64,5 @@ func buildersBedrockBasicAuth(t *testing.T, r *recorder.Recorder) (fantasy.Langu
if err != nil {
return nil, err
}
- return provider.LanguageModel(t.Context(), "us.anthropic.claude-3-sonnet-20240229-v1:0")
+ return provider.LanguageModel(t.Context(), "anthropic.claude-3-sonnet-20240229-v1:0")
}