From c2412e4aaa6c8d44c157530896a8691dd9aa37a4 Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Tue, 21 Nov 2017 17:49:36 +0900 Subject: [PATCH] Use channel to feed filenames --- file_checker.go | 4 ++-- main.go | 39 ++++----------------------------------- utils.go | 41 +++++++++++++++++++++++++++++++++-------- utils_test.go | 10 ++++++---- 4 files changed, 45 insertions(+), 49 deletions(-) diff --git a/file_checker.go b/file_checker.go index 3b08546..fe642d7 100644 --- a/file_checker.go +++ b/file_checker.go @@ -39,10 +39,10 @@ func (c fileChecker) Check(f string) ([]urlResult, error) { return rs, nil } -func (c fileChecker) CheckMany(fs []string, rc chan<- fileResult) { +func (c fileChecker) CheckMany(fc <-chan string, rc chan<- fileResult) { wg := sync.WaitGroup{} - for _, f := range fs { + for f := range fc { wg.Add(1) go func(f string) { diff --git a/main.go b/main.go index 5c647f2..7bf8806 100644 --- a/main.go +++ b/main.go @@ -10,20 +10,15 @@ func main() { os.Exit(1) } - if args.recursive { - args.filenames, err = listFilesRecursively(args.filenames) + fc := make(chan string, 1024) - if err != nil { - printToStderr(err.Error()) - os.Exit(1) - } - } + go findMarkupFiles(args.filenames, args.recursive, fc) - rc := make(chan fileResult, len(args.filenames)) + rc := make(chan fileResult, 1024) s := newSemaphore(args.concurrency) c := newFileChecker(args.timeout, s) - go c.CheckMany(args.filenames, rc) + go c.CheckMany(fc, rc) ok := true @@ -40,29 +35,3 @@ func main() { os.Exit(1) } } - -func listFilesRecursively(fs []string) ([]string, error) { - gs := []string{} - - for _, f := range fs { - i, err := os.Stat(f) - - if err != nil { - return nil, err - } - - if i.IsDir() { - fs, err := listFiles(f) - - if err != nil { - return nil, err - } - - gs = append(gs, fs...) - } else { - gs = append(gs, f) - } - } - - return gs, nil -} diff --git a/utils.go b/utils.go index e9adbad..4e1f8b3 100644 --- a/utils.go +++ b/utils.go @@ -5,7 +5,9 @@ import ( "os" "path/filepath" "regexp" + "strings" + "github.com/fatih/color" "github.com/kr/text" ) @@ -23,14 +25,18 @@ func printToStderr(xs ...interface{}) { fmt.Fprintln(os.Stderr, xs...) } +func fail(err error) { + s := err.Error() + printToStderr(color.RedString(strings.ToUpper(s[:1]) + s[1:])) + os.Exit(1) +} + func indent(s string) string { return text.Indent(s, "\t") } -func listFiles(d string) ([]string, error) { - fs := []string{} - - err := filepath.Walk(d, func(f string, i os.FileInfo, err error) error { +func listDirectory(d string, fc chan<- string) error { + return filepath.Walk(d, func(f string, i os.FileInfo, err error) error { if err != nil { return err } @@ -42,15 +48,34 @@ func listFiles(d string) ([]string, error) { } if !i.IsDir() && !b && isMarkupFile(f) { - fs = append(fs, f) + fc <- f } return nil }) +} + +func findMarkupFiles(fs []string, recursive bool, fc chan<- string) { + for _, f := range fs { + i, err := os.Stat(f) - if err != nil { - return nil, err + if err != nil { + fail(err) + } + + if i.IsDir() && recursive { + err := listDirectory(f, fc) + + if err != nil { + fail(err) + } + + } else if i.IsDir() { + fail(fmt.Errorf("%v is not a file", f)) + } else { + fc <- f + } } - return fs, nil + close(fc) } diff --git a/utils_test.go b/utils_test.go index 25e0d1c..330c53b 100644 --- a/utils_test.go +++ b/utils_test.go @@ -7,13 +7,15 @@ import ( "github.com/stretchr/testify/assert" ) -func TestListFiles(t *testing.T) { - fs, err := listFiles(".") +func TestListDirectory(t *testing.T) { + fc := make(chan string, 1024) + err := listDirectory(".", fc) + close(fc) assert.Equal(t, nil, err) - assert.NotEqual(t, 0, len(fs)) + assert.NotEqual(t, 0, len(fc)) - for _, f := range fs { + for f := range fc { i, err := os.Stat(f) assert.True(t, isMarkupFile(f))