Skip to content

Commit

Permalink
Fix test client setup and result interface tests (#44)
Browse files Browse the repository at this point in the history
* fix client test

* override transport if custom headers exists

* fix primitive type failing test

* fix results tests

* cleanup
  • Loading branch information
NRHelmi authored Nov 15, 2022
1 parent d388926 commit 3e6dfd8
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 142 deletions.
66 changes: 38 additions & 28 deletions rai/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func isErrNotFound(err error) bool {

// Ensure that the test engine exists.
func ensureEngine(client *Client, engine, size string) error {
fmt.Printf("using engine: %s\n", engine)
if _, err := client.GetEngine(engine); err != nil {
if !isErrNotFound(err) {
return err
Expand All @@ -61,6 +62,7 @@ func ensureEngine(client *Client, engine, size string) error {

// Ensure the test database exists.
func ensureDatabase(client *Client, database string) error {
fmt.Printf("using database: %s\n", database)
if _, err := client.GetDatabase(database); err != nil {
if !isErrNotFound(err) {
return err
Expand Down Expand Up @@ -91,21 +93,25 @@ func (h headerRoundTrip) RoundTrip(r *http.Request) (*http.Response, error) {
// available.
func newTestClient() (*Client, error) {
configPath, _ := expandUser(DefaultConfigFile)
var testClient *Client
if _, err := os.Stat(configPath); err == nil {
return NewDefaultClient()
}

var cfg Config
testClient, err = NewDefaultClient()
if err != nil {
panic(err)
}

clientId := os.Getenv("CLIENT_ID")
clientSecret := os.Getenv("CLIENT_SECRET")
clientCredentialsUrl := os.Getenv("CLIENT_CREDENTIALS_URL")
raiHost := os.Getenv("HOST")
if raiHost == "" {
raiHost = "azure.relationalai.com"
}
} else {
var cfg Config

clientId := os.Getenv("CLIENT_ID")
clientSecret := os.Getenv("CLIENT_SECRET")
clientCredentialsUrl := os.Getenv("CLIENT_CREDENTIALS_URL")
raiHost := os.Getenv("HOST")
if raiHost == "" {
raiHost = "azure.relationalai.com"
}

placeHolderConfig := `
configFormat := `
[default]
host=%s
region=us-east
Expand All @@ -114,26 +120,30 @@ func newTestClient() (*Client, error) {
client_id=%s
client_secret=%s
client_credentials_url=%s
`
configSrc := fmt.Sprintf(placeHolderConfig, raiHost, clientId, clientSecret, clientCredentialsUrl)
LoadConfigString(configSrc, "default", &cfg)
opts := ClientOptions{Config: cfg}
testClient := NewClient(context.Background(), &opts)
`
configSrc := fmt.Sprintf(configFormat, raiHost, clientId, clientSecret, clientCredentialsUrl)
LoadConfigString(configSrc, "default", &cfg)
opts := ClientOptions{Config: cfg}
testClient = NewClient(context.Background(), &opts)
}

// get custom headers
var customHeaders map[string]string
json.Unmarshal([]byte(os.Getenv("CUSTOM_HEADERS")), &customHeaders)
if err := json.Unmarshal([]byte(os.Getenv("CUSTOM_HEADERS")), &customHeaders); err == nil {
fmt.Printf("using custom headers: %s\n", customHeaders)

// override default http client roundTrip
var defaultTransport http.RoundTripper
if testClient.HttpClient.Transport == nil {
defaultTransport = http.DefaultTransport
} else {
defaultTransport = testClient.HttpClient.Transport
}

// override default http client roundTrip
var defaultTransport http.RoundTripper
if testClient.HttpClient.Transport == nil {
defaultTransport = http.DefaultTransport
} else {
defaultTransport = testClient.HttpClient.Transport
}
testClient.HttpClient.Transport = headerRoundTrip{
defaultTransport,
customHeaders,
testClient.HttpClient.Transport = headerRoundTrip{
defaultTransport,
customHeaders,
}
}

return testClient, nil
Expand Down
Loading

0 comments on commit 3e6dfd8

Please sign in to comment.