Skip to content

Commit

Permalink
Merge pull request #10 from ikawaha/develop
Browse files Browse the repository at this point in the history
Release candidate
  • Loading branch information
ikawaha authored Feb 26, 2022
2 parents 84ec131 + 71d0af8 commit 6630031
Show file tree
Hide file tree
Showing 10 changed files with 295 additions and 189 deletions.
29 changes: 27 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,41 @@ waifu2x.go is a clone of waifu2x-js.

waifu2x-js: https://github.com/takuyaa/waifu2x-js

Changes
---
* 2022-02-09: Imported changes from [go-waifu2x](https://github.com/puhitaku/go-waifu2x), a fork of this repository. This is an excellent job done by @puhitaku and @orisano. It is 14x faster than the original in the non-GPU case.

Install
---

```shell
go install github.com/ikawaha/waifu2x.go@latest
```

Changes
Usage
---
* 2022-02-09: Imported changes from [go-waifu2x](https://github.com/puhitaku/go-waifu2x), a fork of this repository. This is an excellent job done by @puhitaku and @orisano. It is 14x faster than the original in the non-GPU case.

```shell
$ waifu2x.go --help
Usage of waifu2x:
-i string
input file (default stdin)
-m string
waifu2x mode, choose from 'anime' and 'photo' (default "anime")
-n int
noise reduction level 0 <= n <= 3
-o string
output file (default stdout)
-p int
concurrency (default 8)
-s float
scale multiplier >= 1.0 (default 2)
-v verbose
```

<img width="542" alt="image" src="https://user-images.githubusercontent.com/4232165/155845021-83a90df6-5324-4511-94fc-2d9d4a00273c.png">

The Go gopher was designed by [Renée French](https://reneefrench.blogspot.com/).

Note
---
Expand Down
139 changes: 71 additions & 68 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"bytes"
"context"
"flag"
"fmt"
Expand All @@ -10,50 +11,45 @@ import (
"io"
"os"
"runtime"
"strings"

"github.com/ikawaha/waifu2x.go/engine"
)

const (
commandName = "waifu2x"
usageMessage = "%s (-i|--input) <input_file> [-o|--output <output_file>] [-s|--scale <scale_factor>] [-j|--jobs <n>] [-n|--noise <n>] [-m|--mode (anime|photo)]\n"
)

const (
modeAnime = "anime"
modePhoto = "photo"
)

type option struct {
input string
output string
scale float64
jobs int
noiseReduction int
mode string
flagSet *flag.FlagSet
// flagSet args
input string
output string
scale float64
noise int
parallel int
modeStr string
verbose bool

// option values
mode engine.Mode
flagSet *flag.FlagSet
}

const commandName = `waifu2x`

func newOption(w io.Writer, eh flag.ErrorHandling) (o *option) {
o = &option{
flagSet: flag.NewFlagSet(commandName, eh),
}
// option settings
o.flagSet.SetOutput(w)
o.flagSet.StringVar(&o.input, "i", "", "input file (short)")
o.flagSet.StringVar(&o.input, "input", "", "input file")
o.flagSet.StringVar(&o.output, "o", "", "output file (short) (default stdout)")
o.flagSet.StringVar(&o.output, "output", "", "output file (default stdout)")
o.flagSet.Float64Var(&o.scale, "s", 2.0, "scale multiplier >= 1.0 (short)")
o.flagSet.Float64Var(&o.scale, "scale", 2.0, "scale multiplier >= 1.0")
o.flagSet.IntVar(&o.jobs, "j", runtime.NumCPU(), "# of goroutines (short)")
o.flagSet.IntVar(&o.jobs, "jobs", runtime.NumCPU(), "# of goroutines")
o.flagSet.IntVar(&o.noiseReduction, "n", 0, "noise reduction level 0 <= n <= 3 (short)")
o.flagSet.IntVar(&o.noiseReduction, "noise", 0, "noise reduction level 0 <= n <= 3")
o.flagSet.StringVar(&o.mode, "m", modeAnime, "waifu2x mode, choose from 'anime' and 'photo' (short) (default anime)")
o.flagSet.StringVar(&o.mode, "mode", modeAnime, "waifu2x mode, choose from 'anime' and 'photo' (default anime)")

o.flagSet.StringVar(&o.input, "i", "", "input file (default stdin)")
o.flagSet.StringVar(&o.output, "o", "", "output file (default stdout)")
o.flagSet.Float64Var(&o.scale, "s", 2.0, "scale multiplier >= 1.0")
o.flagSet.IntVar(&o.noise, "n", 0, "noise reduction level 0 <= n <= 3")
o.flagSet.IntVar(&o.parallel, "p", runtime.GOMAXPROCS(runtime.NumCPU()), "concurrency")
o.flagSet.StringVar(&o.modeStr, "m", modeAnime, "waifu2x mode, choose from 'anime' and 'photo'")
o.flagSet.BoolVar(&o.verbose, "v", false, "verbose")
return
}

Expand All @@ -65,72 +61,79 @@ func (o *option) parse(args []string) error {
if nonFlag := o.flagSet.Args(); len(nonFlag) != 0 {
return fmt.Errorf("invalid argument: %v", nonFlag)
}
if o.input == "" {
return fmt.Errorf("input file is empty")
}
if o.scale < 1.0 {
return fmt.Errorf("invalid scale, %v > 1", o.scale)
}
if o.jobs < 1 {
return fmt.Errorf("invalid number of jobs, %v < 1", o.jobs)
if o.noise < 0 || o.noise > 3 {
return fmt.Errorf("invalid number of noise reduction level, it must be [0,3]")
}
if o.noiseReduction < 0 || o.noiseReduction > 3 {
return fmt.Errorf("invalid number of noise reduction level, it must be 0 - 3")
if o.parallel < 1 {
return fmt.Errorf("invalid number of parallel, it must be >= 1")
}
if o.mode != modeAnime && o.mode != modePhoto {
switch o.modeStr {
case modeAnime:
o.mode = engine.Anime
case modePhoto:
o.mode = engine.Photo
default:
return fmt.Errorf("invalid mode, choose from 'anime' or 'photo'")
}
return nil
}

// Usage shows a usage message.
func Usage() {
fmt.Printf(usageMessage, commandName)
opt := newOption(os.Stdout, flag.ContinueOnError)
opt.flagSet.PrintDefaults()
func parseInputImage(file string) (image.Image, error) {
var b []byte
in := os.Stdin
if file != "" {
var err error
b, err = os.ReadFile(file)
if err != nil {
return nil, err
}
} else {
var err error
b, err = io.ReadAll(in)
if err != nil {
return nil, err
}
}
_, format, err := image.DecodeConfig(bytes.NewReader(b))
if err != nil {
return nil, err
}
var decoder func(io.Reader) (image.Image, error)
switch format {
case "jpeg":
decoder = jpeg.Decode
case "png":
decoder = png.Decode
default:
return nil, fmt.Errorf("unsupported image type: %s", format)
}
return decoder(bytes.NewReader(b))
}

// Run executes the waifu2x command.
func Run(args []string) error {
opt := newOption(os.Stderr, flag.ContinueOnError)
opt := newOption(os.Stderr, flag.ExitOnError)
if err := opt.parse(args); err != nil {
return err
}

fp, err := os.Open(opt.input)
img, err := parseInputImage(opt.input)
if err != nil {
return fmt.Errorf("input file %v, %w", opt.input, err)
return fmt.Errorf("input error: %w", err)
}
defer fp.Close()

var img image.Image
if strings.HasSuffix(fp.Name(), "jpg") || strings.HasSuffix(fp.Name(), "jpeg") {
img, err = jpeg.Decode(fp)
if err != nil {
return fmt.Errorf("load file %v, %w", opt.input, err)
}
} else if strings.HasSuffix(fp.Name(), "png") {
img, err = png.Decode(fp)
if err != nil {
return fmt.Errorf("load file %v, %w", opt.input, err)
}
}

mode := engine.Anime
switch opt.mode {
case "anime":
mode = engine.Anime
case "photo":
mode = engine.Photo
}
w2x, err := engine.NewWaifu2x(mode, opt.noiseReduction, engine.Parallel(8), engine.Verbose())
w2x, err := engine.NewWaifu2x(opt.mode, opt.noise, []engine.Option{
engine.Verbose(opt.verbose),
engine.Parallel(opt.parallel),
}...)
if err != nil {
return err
}

rgba, err := w2x.ScaleUp(context.TODO(), img, opt.scale)
if err != nil {
return fmt.Errorf("calc error: %w", err)
return err
}

var w io.Writer = os.Stdout
Expand All @@ -143,7 +146,7 @@ func Run(args []string) error {
w = fp
}
if err := png.Encode(w, &rgba); err != nil {
panic(err)
return fmt.Errorf("output error: %w", err)
}
return nil
}
2 changes: 1 addition & 1 deletion engine/channel_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func NewChannelImageWidthHeight(width, height int) ChannelImage {
return ChannelImage{
Width: width,
Height: height,
Buffer: make([]uint8, width*height), // XXX 0以下を0, 255以上を255 として登録する必要あり
Buffer: make([]uint8, width*height), // note. it is necessary to register all values less than 0 as 0 and greater than 255 as 255
}
}

Expand Down
28 changes: 28 additions & 0 deletions engine/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Param struct {
Weight [][][][]float64 `json:"weight"` // 重み
NInputPlane int `json:"nInputPlane"` // 入力平面数
NOutputPlane int `json:"nOutputPlane"` // 出力平面数
WeightVec []float64
}

// Model represents a trained model.
Expand All @@ -38,6 +39,7 @@ func LoadModel(r io.Reader) (Model, error) {
if err := dec.Decode(&m); err != nil {
return nil, err
}
m.setWeightVec()
return m, nil
}

Expand Down Expand Up @@ -122,3 +124,29 @@ func NewAssetModelSet(t Mode, noiseLevel int) (*ModelSet, error) {
NoiseModel: noise,
}, nil
}

func (m Model) setWeightVec() {
for l := range m {
param := m[l]
// [nOutputPlane][nInputPlane][3][3]
const square = 9
vec := make([]float64, param.NInputPlane*param.NOutputPlane*9)
for i := 0; i < param.NInputPlane; i++ {
for o := 0; o < param.NOutputPlane; o++ {
offset := i*param.NOutputPlane*square + o*square
vec[offset+0] = param.Weight[o][i][0][0]
vec[offset+1] = param.Weight[o][i][0][1]
vec[offset+2] = param.Weight[o][i][0][2]

vec[offset+3] = param.Weight[o][i][1][0]
vec[offset+4] = param.Weight[o][i][1][1]
vec[offset+5] = param.Weight[o][i][1][2]

vec[offset+6] = param.Weight[o][i][2][0]
vec[offset+7] = param.Weight[o][i][2][1]
vec[offset+8] = param.Weight[o][i][2][2]
}
}
m[l].WeightVec = vec
}
}
36 changes: 36 additions & 0 deletions engine/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,39 @@ func TestLoadModel(t *testing.T) {
}
}
}

func Test_setWeightVec(t *testing.T) {
model, err := LoadModelFile("./model/anime_style_art/scale2.0x_model.json")
if err != nil {
t.Fatalf("unexpected error, %v", err)
}
matrix := typeW(model)
model.setWeightVec()
for i, param := range model {
for j, v := range param.WeightVec {
if matrix[i][j] != v {
t.Fatalf("[%d, %d]=%v <> %v", i, j, matrix[i][j], v)
}
}
}
}

// W[][O*I*9]
func typeW(model Model) [][]float64 {
var W [][]float64
for l := range model {
// initialize weight matrix
param := model[l]
var vec []float64
// [nOutputPlane][nInputPlane][3][3]
for i := 0; i < param.NInputPlane; i++ {
for o := 0; o < param.NOutputPlane; o++ {
vec = append(vec, param.Weight[o][i][0]...)
vec = append(vec, param.Weight[o][i][1]...)
vec = append(vec, param.Weight[o][i][2]...)
}
}
W = append(W, vec)
}
return W
}
Loading

0 comments on commit 6630031

Please sign in to comment.