bedrock_test.go

 1package providertests
 2
 3import (
 4	"net/http"
 5	"os"
 6	"testing"
 7
 8	"charm.land/fantasy"
 9	"charm.land/fantasy/providers/bedrock"
10	"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
11)
12
13func TestBedrockCommon(t *testing.T) {
14	testCommon(t, []builderPair{
15		{"bedrock-anthropic-claude-3-sonnet", builderBedrockClaude3Sonnet, nil},
16		{"bedrock-anthropic-claude-3-opus", builderBedrockClaude3Opus, nil},
17		{"bedrock-anthropic-claude-3-haiku", builderBedrockClaude3Haiku, nil},
18	})
19}
20
21func TestBedrockBasicAuth(t *testing.T) {
22	testSimple(t, builderPair{"bedrock-anthropic-claude-3-sonnet", buildersBedrockBasicAuth, nil})
23}
24
25func builderBedrockClaude3Sonnet(r *recorder.Recorder) (fantasy.LanguageModel, error) {
26	provider := bedrock.New(
27		bedrock.WithHTTPClient(&http.Client{Transport: r}),
28		bedrock.WithSkipAuth(!r.IsRecording()),
29	)
30	return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0")
31}
32
33func builderBedrockClaude3Opus(r *recorder.Recorder) (fantasy.LanguageModel, error) {
34	provider := bedrock.New(
35		bedrock.WithHTTPClient(&http.Client{Transport: r}),
36		bedrock.WithSkipAuth(!r.IsRecording()),
37	)
38	return provider.LanguageModel("us.anthropic.claude-3-opus-20240229-v1:0")
39}
40
41func builderBedrockClaude3Haiku(r *recorder.Recorder) (fantasy.LanguageModel, error) {
42	provider := bedrock.New(
43		bedrock.WithHTTPClient(&http.Client{Transport: r}),
44		bedrock.WithSkipAuth(!r.IsRecording()),
45	)
46	return provider.LanguageModel("us.anthropic.claude-3-haiku-20240307-v1:0")
47}
48
49func buildersBedrockBasicAuth(r *recorder.Recorder) (fantasy.LanguageModel, error) {
50	provider := bedrock.New(
51		bedrock.WithHTTPClient(&http.Client{Transport: r}),
52		bedrock.WithAPIKey(os.Getenv("FANTASY_BEDROCK_API_KEY")),
53		bedrock.WithSkipAuth(true),
54	)
55	return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0")
56}