Skip to content

Commit

Permalink
Load the corresponding hook when calling
Browse files Browse the repository at this point in the history
  • Loading branch information
aooohan committed Mar 25, 2024
1 parent 2869781 commit 3919d2b
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 71 deletions.
9 changes: 9 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
coverage:
status:
project:
default:
target: auto
# allow coverage to drop by this amount and still post success
threshold: 0.5%
if_ci_failed: error
patch: off # no github status notice for coverage of the PR diff.
16 changes: 8 additions & 8 deletions internal/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (m *Manager) EnvKeys() (*env.Envs, error) {

// LookupSdk lookup sdk by name
func (m *Manager) LookupSdk(name string) (*Sdk, error) {
pluginPath := filepath.Join(m.PathMeta.PluginPath, strings.ToLower(name), "main.lua")
pluginPath := filepath.Join(m.PathMeta.PluginPath, strings.ToLower(name))
if !util.FileExists(pluginPath) {
oldPath := filepath.Join(m.PathMeta.PluginPath, strings.ToLower(name)+".lua")
if !util.FileExists(oldPath) {
Expand All @@ -85,7 +85,7 @@ func (m *Manager) LookupSdk(name string) (*Sdk, error) {
if err != nil {
return nil, fmt.Errorf("failed to migrate an old plug-in: %w", err)
}
if err = os.Rename(oldPath, pluginPath); err != nil {
if err = os.Rename(oldPath, filepath.Join(pluginPath, "main.lua")); err != nil {
return nil, fmt.Errorf("failed to migrate an old plug-in: %w", err)
}
}
Expand Down Expand Up @@ -124,9 +124,9 @@ func (m *Manager) LoadAllSdk() (map[string]*Sdk, error) {
} else {
continue
}
source, err := NewLuaPlugin(path, m)
source, err := NewLuaPlugin(filepath.Dir(path), m)
if err != nil {
pterm.Printf("Failed to load %s plugin, err: %s\n", path, err)
pterm.Printf("Failed to load %s plugin, err: %s\n", filepath.Dir(path), err)
continue
}
sdk, _ := NewSdk(m, source)
Expand Down Expand Up @@ -182,16 +182,16 @@ func (m *Manager) Update(pluginName string) error {
return fmt.Errorf("check %s plugin failed, err: %w", updateUrl, err)
}
success := false
backupPath := sdk.Plugin.Filepath + ".bak"
err = util.CopyFile(sdk.Plugin.Filepath, backupPath)
backupPath := sdk.Plugin.Path + ".bak"
err = util.CopyFile(sdk.Plugin.Path, backupPath)
if err != nil {
return fmt.Errorf("backup %s plugin failed, err: %w", updateUrl, err)
}
defer func() {
if success {
_ = os.Remove(backupPath)
} else {
_ = os.Rename(backupPath, sdk.Plugin.Filepath)
_ = os.Rename(backupPath, sdk.Plugin.Path)
}
}()
pterm.Println("Checking plugin version...")
Expand All @@ -200,7 +200,7 @@ func (m *Manager) Update(pluginName string) error {
pterm.Printf("the plugin is already the latest version")
return nil
}
err = os.WriteFile(sdk.Plugin.Filepath, []byte(content), 0644)
err = os.WriteFile(sdk.Plugin.Path, []byte(content), 0644)
if err != nil {
return fmt.Errorf("update %s plugin failed: %w", updateUrl, err)
}
Expand Down
135 changes: 81 additions & 54 deletions internal/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,35 @@ type HookFunc struct {
}

var (
HookFuncs = []HookFunc{
{Name: "Available", Required: true, Filename: "available"},
{Name: "PreInstall", Required: true, Filename: "pre_install"},
{Name: "EnvKeys", Required: true, Filename: "env_keys"},
{Name: "PostInstall", Required: false, Filename: "post_install"},
{Name: "PreUse", Required: false, Filename: "pre_use"},
// HookFuncMap is a map of built-in hook functions.
HookFuncMap = map[string]HookFunc{
"Available": {Name: "Available", Required: true, Filename: "available"},
"PreInstall": {Name: "PreInstall", Required: true, Filename: "pre_install"},
"EnvKeys": {Name: "EnvKeys", Required: true, Filename: "env_keys"},
"PostInstall": {Name: "PostInstall", Required: false, Filename: "post_install"},
"PreUse": {Name: "PreUse", Required: false, Filename: "pre_use"},
}
)

type LuaPlugin struct {
vm *luai.LuaVM
pluginObj *lua.LTable
// isLegacyMode indicates whether the plugin is in the old format, all in main.lua.
isLegacyMode bool
// plugin source path
Filepath string
Path string
// plugin filename, this is also alias name, sdk-name
SdkName string
// The name defined inside the plugin

LuaPluginInfo
}

func (l *LuaPlugin) checkValid() error {
if l.vm == nil || l.vm.Instance == nil {
return fmt.Errorf("lua vm is nil")
}

if !l.HasFunction("Available") {
return fmt.Errorf("[Available] function not found")
}
if !l.HasFunction("PreInstall") {
return fmt.Errorf("[PreInstall] function not found")
}
if !l.HasFunction("EnvKeys") {
return fmt.Errorf("[EnvKeys] function not found")
func (l *LuaPlugin) Validate() error {
for _, hf := range HookFuncMap {
if hf.Required {
if !l.HasFunction(hf.Name) {
return fmt.Errorf("[%s] function not found", hf.Name)
}
}
}
return nil
}
Expand All @@ -86,6 +81,9 @@ func (l *LuaPlugin) Close() {
}

func (l *LuaPlugin) Available() ([]*Package, error) {
if goon, err := l.loadHookFunc("Available"); err != nil || !goon {
return nil, err
}
L := l.vm.Instance
ctxTable, err := luai.Marshal(L, AvailableHookCtx{
RuntimeVersion: RuntimeVersion,
Expand Down Expand Up @@ -145,6 +143,9 @@ func (l *LuaPlugin) Available() ([]*Package, error) {
}

func (l *LuaPlugin) PreInstall(version Version) (*Package, error) {
if goon, err := l.loadHookFunc("PreInstall"); err != nil || !goon {
return nil, err
}
L := l.vm.Instance
ctxTable, err := luai.Marshal(L, PreInstallHookCtx{
Version: string(version),
Expand Down Expand Up @@ -194,11 +195,10 @@ func (l *LuaPlugin) PreInstall(version Version) (*Package, error) {
}

func (l *LuaPlugin) PostInstall(rootPath string, sdks []*Info) error {
L := l.vm.Instance

if !l.HasFunction("PostInstall") {
return nil
if goon, err := l.loadHookFunc("PostInstall"); err != nil || !goon {
return err
}
L := l.vm.Instance

ctx := &PostInstallHookCtx{
RuntimeVersion: RuntimeVersion,
Expand All @@ -223,6 +223,9 @@ func (l *LuaPlugin) PostInstall(rootPath string, sdks []*Info) error {
}

func (l *LuaPlugin) EnvKeys(sdkPackage *Package) (*env.Envs, error) {
if goon, err := l.loadHookFunc("EnvKeys"); err != nil || !goon {
return nil, err
}
L := l.vm.Instance
mainInfo := sdkPackage.Main

Expand Down Expand Up @@ -282,10 +285,18 @@ func (l *LuaPlugin) Label(version string) string {
}

func (l *LuaPlugin) HasFunction(name string) bool {
if !l.isLegacyMode {
if _, err := l.loadHookFunc(name); err != nil {
return false
}
}
return l.pluginObj.RawGetString(name) != lua.LNil
}

func (l *LuaPlugin) PreUse(version Version, previousVersion Version, scope UseScope, cwd string, installedSdks []*Package) (Version, error) {
if goon, err := l.loadHookFunc("PreUse"); err != nil || !goon {
return version, err
}
L := l.vm.Instance

ctx := PreUseHookCtx{
Expand All @@ -309,10 +320,6 @@ func (l *LuaPlugin) PreUse(version Version, previousVersion Version, scope UseSc
return "", err
}

if !l.HasFunction("PreUse") {
return "", nil
}

if err = l.CallFunction("PreUse", ctxTable); err != nil {
return "", err
}
Expand All @@ -331,6 +338,36 @@ func (l *LuaPlugin) PreUse(version Version, previousVersion Version, scope UseSc
return Version(result.Version), nil
}

// loadHookFunc loads the specified hook function.
// If the function is not built-in, it will return an error.
// If the function is not required and the file or the hook does not exist, it will return false
// and no error.
func (l *LuaPlugin) loadHookFunc(funcName string) (bool, error) {
if l.isLegacyMode {
return true, nil
}
hf, ok := HookFuncMap[funcName]
if !ok {
return false, fmt.Errorf("%s is not built-hook func", funcName)
}
hp := filepath.Join(l.Path, "hooks", hf.Filename+".lua")
logger.Debugf("Load [%s] function, from: %s\n", hf.Name, hp)
if !util.FileExists(hp) {
if hf.Required {
return false, fmt.Errorf("hook [%s] func not implemented", hf.Name)
}
return false, nil
}
if err := l.vm.Instance.DoFile(hp); err != nil {
return false, fmt.Errorf("failed to load [%s] hook function: %s", hf.Name, err.Error())
}
exist := l.pluginObj.RawGetString(hf.Name) != lua.LNil
if hf.Required && !exist {
return false, fmt.Errorf("hook [%s] func not implemented", hf.Name)
}
return true, nil
}

func (l *LuaPlugin) CallFunction(funcName string, args ...lua.LValue) error {
logger.Debugf("CallFunction: %s\n", funcName)
if err := l.vm.CallFunction(l.pluginObj.RawGetString(funcName), append([]lua.LValue{l.pluginObj}, args...)...); err != nil {
Expand Down Expand Up @@ -367,13 +404,14 @@ func NewLegacyLuaPlugin(content, path string, manager *Manager) (*LuaPlugin, err
PLUGIN := pluginObj.(*lua.LTable)

source := &LuaPlugin{
vm: vm,
pluginObj: PLUGIN,
Filepath: path,
SdkName: filepath.Base(filepath.Dir(path)),
vm: vm,
pluginObj: PLUGIN,
Path: path,
SdkName: filepath.Base(filepath.Dir(path)),
isLegacyMode: true,
}

if err := source.checkValid(); err != nil {
if err := source.Validate(); err != nil {
return nil, err
}

Expand Down Expand Up @@ -409,12 +447,14 @@ func NewLuaPlugin(pluginDirPath string, manager *Manager) (*LuaPlugin, error) {
}

mainPath := filepath.Join(pluginDirPath, "main.lua")
isLegacyMode := false
// main.lua first
if util.FileExists(mainPath) {
vm.LimitPackagePath(filepath.Join(pluginDirPath, "?.lua"))
if err := vm.Instance.DoFile(mainPath); err != nil {
return nil, err
}
isLegacyMode = true
} else {
// Limit package search scope, hooks directory search priority is higher than lib directory
hookPath := filepath.Join(pluginDirPath, "hooks", "?.lua")
Expand All @@ -431,17 +471,7 @@ func NewLuaPlugin(pluginDirPath string, manager *Manager) (*LuaPlugin, error) {
return nil, fmt.Errorf("failed to load meatadata file, %w", err)
}

// load hook func files
for _, hf := range HookFuncs {
hp := filepath.Join(pluginDirPath, "hooks", hf.Filename+".lua")

if !hf.Required && !util.FileExists(hp) {
continue
}
if err := vm.Instance.DoFile(hp); err != nil {
return nil, fmt.Errorf("failed to load [%s] hook function: %s", hf.Name, err.Error())
}
}
isLegacyMode = false
}

// !!!! Must be set after loading the script to prevent overwriting!
Expand All @@ -467,14 +497,11 @@ func NewLuaPlugin(pluginDirPath string, manager *Manager) (*LuaPlugin, error) {
PLUGIN := pluginObj.(*lua.LTable)

source := &LuaPlugin{
vm: vm,
pluginObj: PLUGIN,
Filepath: pluginDirPath,
SdkName: filepath.Base(pluginDirPath),
}

if err = source.checkValid(); err != nil {
return nil, err
vm: vm,
pluginObj: PLUGIN,
Path: pluginDirPath,
SdkName: filepath.Base(pluginDirPath),
isLegacyMode: isLegacyMode,
}

pluginInfo := LuaPluginInfo{}
Expand Down
13 changes: 4 additions & 9 deletions internal/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func TestNewLuaPluginWithMain(t *testing.T) {
t.Errorf("expected filename 'java', got '%s'", plugin.SdkName)
}

if plugin.Filepath != pluginPathWithMain {
t.Errorf("expected filepath '%s', got '%s'", pluginPathWithMain, plugin.Filepath)
if plugin.Path != pluginPathWithMain {
t.Errorf("expected filepath '%s', got '%s'", pluginPathWithMain, plugin.Path)
}

if plugin.Name != "java" {
Expand Down Expand Up @@ -89,8 +89,8 @@ func TestNewLuaPluginWithMetadataAndHooks(t *testing.T) {
t.Errorf("expected filename 'java', got '%s'", plugin.SdkName)
}

if plugin.Filepath != pluginPathWithMetadata {
t.Errorf("expected filepath '%s', got '%s'", pluginPathWithMetadata, plugin.Filepath)
if plugin.Path != pluginPathWithMetadata {
t.Errorf("expected filepath '%s', got '%s'", pluginPathWithMetadata, plugin.Path)
}

if plugin.Name != "java" {
Expand All @@ -112,11 +112,6 @@ func TestNewLuaPluginWithMetadataAndHooks(t *testing.T) {
if plugin.MinRuntimeVersion != "0.2.2" {
t.Errorf("expected min runtime version '0.2.2', got '%s'", plugin.MinRuntimeVersion)
}
for _, hf := range HookFuncs {
if !plugin.HasFunction(hf.Name) && hf.Required {
t.Errorf("expected to have function %s", hf.Name)
}
}
})
testHookFunc(t, func() (*LuaPlugin, error) {
manager := NewSdkManager()
Expand Down

0 comments on commit 3919d2b

Please sign in to comment.