diff --git a/depstubber.go b/depstubber.go index 1b0c046..d9907b0 100644 --- a/depstubber.go +++ b/depstubber.go @@ -6,7 +6,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "log" "os" "path/filepath" @@ -153,7 +152,7 @@ func createStubs(packageName string, typeNames []string, funcAndVarNames []strin g.srcFunctions = strings.Join(funcAndVarNames, ",") if *copyrightFile != "" { - header, err := ioutil.ReadFile(*copyrightFile) + header, err := os.ReadFile(*copyrightFile) if err != nil { log.Fatalf("Failed reading copyright file: %v", err) } diff --git a/reflect.go b/reflect.go index 6ad6087..f29b8df 100644 --- a/reflect.go +++ b/reflect.go @@ -8,7 +8,6 @@ import ( "flag" "fmt" "go/build" - "io/ioutil" "log" "os" "os/exec" @@ -44,7 +43,7 @@ func writeProgram(importPath string, types []string, values []string) ([]byte, e // run the given program and parse the output as a model.Package. func run(program string) (*model.PackedPkg, error) { - f, err := ioutil.TempFile("", "") + f, err := os.CreateTemp("", "") if err != nil { return nil, err } @@ -85,7 +84,7 @@ func run(program string) (*model.PackedPkg, error) { // parses the output as a model.Package. func runInDir(program []byte, dir string) (*model.PackedPkg, error) { // We use TempDir instead of TempFile so we can control the filename. - tmpDir, err := ioutil.TempDir(dir, "depstubber_reflect_") + tmpDir, err := os.MkdirTemp(dir, "depstubber_reflect_") if err != nil { return nil, err } @@ -101,7 +100,7 @@ func runInDir(program []byte, dir string) (*model.PackedPkg, error) { progBinary += ".exe" } - if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { + if err := os.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { return nil, err } diff --git a/util.go b/util.go index 6d9fe33..8de60cc 100644 --- a/util.go +++ b/util.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "os" "runtime/debug" @@ -23,7 +22,7 @@ func removeDot(s string) string { // packageNameOfDir get package import path via dir func packageNameOfDir(srcDir string) (string, error) { - files, err := ioutil.ReadDir(srcDir) + files, err := os.ReadDir(srcDir) if err != nil { log.Fatal(err) } diff --git a/vendor.go b/vendor.go index 4c437d9..a9bde9d 100644 --- a/vendor.go +++ b/vendor.go @@ -4,10 +4,13 @@ package main import ( "bytes" - "io/ioutil" + "go/parser" + "go/token" "log" "os" "path/filepath" + "regexp" + "sort" "strings" "golang.org/x/mod/modfile" @@ -38,7 +41,7 @@ func findModuleRoot(dir string) (root string) { } func loadModFile(filename string) *modfile.File { - data, err := ioutil.ReadFile(filename) + data, err := os.ReadFile(filename) if err != nil { panic(err) } @@ -85,24 +88,30 @@ func stubModulesTxt() { } modFile := loadModFile(filepath.Join(modRoot, "go.mod")) - vdir := filepath.Join(modRoot, "vendor") if gv := modFile.Go; gv != nil && semver.Compare("v"+gv.Version, "v1.14") >= 0 { - // If the Go version is at least 1.14, generate a dummy modules.txt using only the information - // in the go.mod file + // Find imports from all Go files in the project + usedPackages := findPackagesInSourceCode(modRoot) generated := make(map[module.Version]bool) var buf bytes.Buffer for _, r := range modFile.Require { - // TODO: support replace lines generated[r.Mod] = true line := moduleLine(r.Mod, module.Version{}) buf.WriteString(line) - buf.WriteString("## explicit\n") - buf.WriteString(r.Mod.Path + "\n") + // List package paths that are used in the source code + packagesForModule := findPackagesForModule(r.Mod.Path, usedPackages) + if len(packagesForModule) > 0 { + for _, pkg := range packagesForModule { + buf.WriteString(pkg + "\n") + } + } else { + // If we can't find any packages then just list the module path itself + buf.WriteString(r.Mod.Path + "\n") + } } // Record unused and wildcard replacements at the end of the modules.txt file: @@ -128,8 +137,74 @@ func stubModulesTxt() { log.Fatalf("go mod vendor: %v", err) } - if err := ioutil.WriteFile(filepath.Join(vdir, "modules.txt"), buf.Bytes(), 0666); err != nil { + if err := os.WriteFile(filepath.Join(vdir, "modules.txt"), buf.Bytes(), 0666); err != nil { log.Fatalf("go mod vendor: %v", err) } } } + +// findPackagesInSourceCode scans all Go files in the directory tree and extracts import paths +func findPackagesInSourceCode(root string) map[string]bool { + packages := make(map[string]bool) + + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip vendor directory and hidden directories + if info.IsDir() && (info.Name() == "vendor" || strings.HasPrefix(info.Name(), ".")) { + return filepath.SkipDir + } + + // Only process Go files + if !info.IsDir() && strings.HasSuffix(path, ".go") { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly) + if err != nil { + return err + } + + // Extract import paths from the AST + for _, imp := range file.Imports { + pkgPath := strings.Trim(imp.Path.Value, "\"") + packages[pkgPath] = true + } + } + return nil + }) + + if err != nil { + log.Printf("Warning: error walking source directory: %v", err) + } + + return packages +} + +// Compile the regular expression once +var majorVersionSuffixRegex = regexp.MustCompile(`^/v[1-9][0-9]*(/|$)`) + +// findPackagesForModule returns the submodules of a given module that are actually used in the source code +func findPackagesForModule(modulePath string, usedPackages map[string]bool) []string { + var packages []string + + for pkg := range usedPackages { + // Check if this package belongs to the module + if strings.HasPrefix(pkg, modulePath) { + // Extract the part after modulePath + suffix := pkg[len(modulePath):] + + // If `suffix` begins with a major version suffix then we do not have the right module + // path. For example, if the module path is `example.com/mymodule` and the package path + // is `example.com/mymodule/v2/submodule` then we should not consider it a match - it + // is really a match for the module `example.com/mymodule/v2`. + if !majorVersionSuffixRegex.MatchString(suffix) { + packages = append(packages, pkg) + } + } + } + + // Sort packages for consistent output + sort.Strings(packages) + return packages +}