Skip to content

Commit

Permalink
Merge pull request #21 from EGT-Ukraine/feature/expose_interface
Browse files Browse the repository at this point in the history
Expose generator interface.
  • Loading branch information
vektah authored Jan 26, 2019
2 parents 6f1fe26 + 3b61613 commit 29544c1
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 115 deletions.
35 changes: 27 additions & 8 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Gopkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
version = "1.2.1"

[[constraint]]
branch = "master"
revision = "66487607e2081c7c2af2281c62c14ee000d5024b"
name = "golang.org/x/tools"

[prune]
Expand Down
111 changes: 9 additions & 102 deletions dataloaden.go
Original file line number Diff line number Diff line change
@@ -1,126 +1,33 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/build"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
"unicode"

"golang.org/x/tools/imports"
"github.com/vektah/dataloaden/pkg/generator"
)

var (
keyType = flag.String("keys", "int", "what type should the keys be")
slice = flag.Bool("slice", false, "this dataloader will return slices")
)

type templateData struct {
LoaderName string
BatchName string
Package string
Name string
KeyType string
ValType string
Import string
Slice bool
}

func main() {
keyType := flag.String("keys", "int", "what type should the keys be")
slice := flag.Bool("slice", false, "this dataloader will return slices")

flag.Parse()

if flag.NArg() != 1 {
flag.Usage()
os.Exit(1)
}

data, err := getData(flag.Arg(0))
if err != nil {
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(2)
}

filename := data.Name + "loader_gen.go"
if data.Slice {
filename = data.Name + "sliceloader_gen.go"
}

writeTemplate(filename, data)
}

func getData(typeName string) (templateData, error) {
var data templateData
parts := strings.Split(typeName, ".")
if len(parts) < 2 {
return templateData{}, fmt.Errorf("type must be in the form package.Name")
}

wd, err := os.Getwd()
if err != nil {
return templateData{}, fmt.Errorf("cant determine working dir: %s", err.Error())
}

pkgData := getPackage(wd)
name := parts[len(parts)-1]
data.Package = pkgData
data.LoaderName = name + "Loader"
data.BatchName = lcFirst(name) + "Batch"
data.Name = lcFirst(name)
data.KeyType = *keyType
data.Slice = *slice

prefix := "*"
if *slice {
prefix = "[]"
data.LoaderName = name + "SliceLoader"
data.BatchName = lcFirst(name) + "SliceBatch"
}

// if we are inside the same package as the type we don't need an import and can refer directly to the type
pkgName := strings.Join(parts[:len(parts)-1], ".")
if strings.HasSuffix(filepath.ToSlash(wd), pkgName) {
data.ValType = prefix + name
} else {
data.Import = pkgName
data.ValType = prefix + filepath.Base(data.Import) + "." + name
}

return data, nil
}

func getPackage(wd string) string {
result, err := build.ImportDir(wd, build.IgnoreVendor)
if err != nil {
return filepath.Base(wd)
}

return result.Name
}

func writeTemplate(filename string, data templateData) {
var buf bytes.Buffer
if err := tpl.Execute(&buf, data); err != nil {
log.Fatalf("generating code: %v", err)
}

src, err := imports.Process(filename, buf.Bytes(), nil)
if err != nil {
log.Printf("unable to gofmt: %s", err.Error())
src = buf.Bytes()
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(2)
}

if err := ioutil.WriteFile(filename, src, 0644); err != nil {
log.Fatalf("writing output: %s", err)
if err := generator.Generate(flag.Arg(0), *keyType, *slice, wd); err != nil {
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(2)
}
}

func lcFirst(s string) string {
r := []rune(s)
r[0] = unicode.ToLower(r[0])
return string(r)
}
2 changes: 1 addition & 1 deletion example/pkgname/user.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package differentpkg

//go:generate go run ../../dataloaden.go ../../template.go -keys string github.com/vektah/dataloaden/example.User
//go:generate go run ../../dataloaden.go -keys string github.com/vektah/dataloaden/example.User
2 changes: 1 addition & 1 deletion example/slice/user.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:generate go run ../../dataloaden.go ../../template.go -keys int -slice github.com/vektah/dataloaden/example.User
//go:generate go run ../../dataloaden.go -keys int -slice github.com/vektah/dataloaden/example.User

package slice

Expand Down
2 changes: 1 addition & 1 deletion example/user.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//go:generate go run ../dataloaden.go ../template.go -keys string github.com/vektah/dataloaden/example.User
//go:generate go run ../dataloaden.go -keys string github.com/vektah/dataloaden/example.User

package example

Expand Down
111 changes: 111 additions & 0 deletions pkg/generator/generator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package generator

import (
"bytes"
"fmt"
"go/build"
"io/ioutil"
"path/filepath"
"strings"
"unicode"

"github.com/pkg/errors"
"golang.org/x/tools/imports"
)

type templateData struct {
LoaderName string
BatchName string
Package string
Name string
KeyType string
ValType string
Import string
Slice bool
}

func Generate(typename string, keyType string, slice bool, wd string) error {
data, err := getData(typename, keyType, slice, wd)
if err != nil {
return err
}

filename := data.Name + "loader_gen.go"
if data.Slice {
filename = data.Name + "sliceloader_gen.go"
}

if err := writeTemplate(filepath.Join(wd, filename), data); err != nil {
return err
}

return nil
}

func getData(typeName string, keyType string, slice bool, wd string) (templateData, error) {
var data templateData
parts := strings.Split(typeName, ".")
if len(parts) < 2 {
return templateData{}, fmt.Errorf("type must be in the form package.Name")
}

pkgData := getPackage(wd)
name := parts[len(parts)-1]
data.Package = pkgData
data.LoaderName = name + "Loader"
data.BatchName = lcFirst(name) + "Batch"
data.Name = lcFirst(name)
data.KeyType = keyType
data.Slice = slice

prefix := "*"
if slice {
prefix = "[]"
data.LoaderName = name + "SliceLoader"
data.BatchName = lcFirst(name) + "SliceBatch"
}

// if we are inside the same package as the type we don't need an import and can refer directly to the type
pkgName := strings.Join(parts[:len(parts)-1], ".")
if strings.HasSuffix(filepath.ToSlash(wd), pkgName) {
data.ValType = prefix + name
} else {
data.Import = pkgName
data.ValType = prefix + filepath.Base(data.Import) + "." + name
}

return data, nil
}

func getPackage(wd string) string {
result, err := build.ImportDir(wd, build.IgnoreVendor)
if err != nil {
return filepath.Base(wd)
}

return result.Name
}

func writeTemplate(filepath string, data templateData) error {
var buf bytes.Buffer
if err := tpl.Execute(&buf, data); err != nil {
return errors.Wrap(err, "generating code")
}

src, err := imports.Process(filepath, buf.Bytes(), nil)
if err != nil {
return errors.Wrap(err, "unable to gofmt")
}

if err := ioutil.WriteFile(filepath, src, 0644); err != nil {
return errors.Wrap(err, "writing output")
}

return nil
}

func lcFirst(s string) string {
r := []rune(s)
r[0] = unicode.ToLower(r[0])
return string(r)
}
2 changes: 1 addition & 1 deletion template.go → pkg/generator/template.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 29544c1

Please sign in to comment.