Skip to content

Commit

Permalink
Create consistent workunit IDs across all nodes (#1194)
Browse files Browse the repository at this point in the history
  • Loading branch information
matoval authored Nov 12, 2024
1 parent 5b6f448 commit 25ff7aa
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/source/user_guide/configuration_options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ Node
- Default value
- Type
* - ``id``
- Node ID
- Node ID can only contain a-z, A-Z, 0-9 or special characters . - _ @
- local hostname
- string
* - ``datadir``
Expand Down
2 changes: 1 addition & 1 deletion pkg/controlsvc/controlsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ func (s *Server) RunControlSession(conn net.Conn) {
}
}
} else {
writeMsg := "ERROR: Unknown command\n"
writeMsg := fmt.Sprintf("ERROR: Unknown command, %v\n", cmd)
if writeToConnWithLog(conn, s.nc, writeMsg, writeControlServiceError) {
return
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/types/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"regexp"
"strings"

"github.com/ansible/receptor/pkg/controlsvc"
Expand Down Expand Up @@ -34,6 +35,12 @@ func (cfg NodeCfg) Init() error {
return fmt.Errorf("no node ID specified and local host name is localhost")
}
cfg.ID = host
} else {
submitIDRegex := regexp.MustCompile(`^[.\-_@a-zA-Z0-9]*$`)
match := submitIDRegex.FindSubmatch([]byte(cfg.ID))
if match == nil {
return fmt.Errorf("node id can only contain a-z, A-Z, 0-9 or special characters . - _ @ but received: %s", cfg.ID)
}
}
if strings.ToLower(cfg.ID) == "localhost" {
return fmt.Errorf("node ID \"localhost\" is reserved")
Expand Down
8 changes: 6 additions & 2 deletions pkg/workceptor/controlsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ func (c *workceptorCommand) ControlFunc(ctx context.Context, nc controlsvc.Netce
if err != nil {
signature = ""
}
workUnitID, err := strFromMap(c.params, "workUnitID")
if err != nil {
workUnitID = ""
}
workParams := make(map[string]string)
nonParams := []string{"command", "subcommand", "node", "worktype", "tlsclient", "ttl", "signwork", "signature"}
inNonParams := func(p string) bool {
Expand Down Expand Up @@ -283,9 +287,9 @@ func (c *workceptorCommand) ControlFunc(ctx context.Context, nc controlsvc.Netce
if ttl != "" {
return nil, fmt.Errorf("ttl option is intended for remote work only")
}
worker, err = c.w.AllocateUnit(workType, workParams)
worker, err = c.w.AllocateUnit(workType, workUnitID, workParams)
} else {
worker, err = c.w.AllocateRemoteUnit(workNode, workType, tlsClient, ttl, signWork, workParams)
worker, err = c.w.AllocateRemoteUnit(workNode, workType, workUnitID, tlsClient, ttl, signWork, workParams)
}
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/workceptor/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestWorkceptorJson(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cw, err := w.AllocateUnit("command", make(map[string]string))
cw, err := w.AllocateUnit("command", "", make(map[string]string))
if err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/workceptor/remote_work.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader
for k, v := range red.RemoteParams {
workSubmitCmd[k] = v
}
workSubmitCmd["workUnitID"] = rw.ID()
workSubmitCmd["command"] = "work"
workSubmitCmd["subcommand"] = "submit"
workSubmitCmd["node"] = red.RemoteNode
Expand Down Expand Up @@ -183,7 +184,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader
if err != nil {
return fmt.Errorf("read error reading from %s: %s", red.RemoteNode, err)
}
submitIDRegex := regexp.MustCompile(`with ID ([a-zA-Z0-9]+)\.`)
submitIDRegex := regexp.MustCompile(`with ID ([.\-_@a-zA-Z0-9]+)\.`)
match := submitIDRegex.FindSubmatch([]byte(response))
if match == nil || len(match) != 2 {
return fmt.Errorf("could not parse response: %s", strings.TrimRight(response, "\n"))
Expand Down
23 changes: 17 additions & 6 deletions pkg/workceptor/workceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,25 @@ func (w *Workceptor) RegisterWorker(typeName string, newWorkerFunc NewWorkerFunc
return nil
}

func (w *Workceptor) generateUnitID(lock bool) (string, error) {
func (w *Workceptor) generateUnitID(lock bool, workUnitID string) (string, error) {
if lock {
w.activeUnitsLock.RLock()
defer w.activeUnitsLock.RUnlock()
}
var ident string
for {
ident = randstr.RandomString(8)
if workUnitID == "" {
rstr := randstr.RandomString(8)
nid := w.nc.NodeID()
ident = fmt.Sprintf("%s%s", nid, rstr)
} else {
ident = workUnitID
unitdir := path.Join(w.dataDir, ident)
_, err := os.Stat(unitdir)
if err == nil {
return "", fmt.Errorf("workunit ID %s is already in use, cannot use the same workunit ID more than once", ident)
}
}
_, ok := w.activeUnits[ident]
if !ok {
unitdir := path.Join(w.dataDir, ident)
Expand Down Expand Up @@ -243,7 +254,7 @@ func (w *Workceptor) VerifySignature(signature string) error {
}

// AllocateUnit creates a new local work unit and generates an identifier for it.
func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string) (WorkUnit, error) {
func (w *Workceptor) AllocateUnit(workTypeName string, workUnitID string, params map[string]string) (WorkUnit, error) {
w.workTypesLock.RLock()
wt, ok := w.workTypes[workTypeName]
w.workTypesLock.RUnlock()
Expand All @@ -252,7 +263,7 @@ func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string)
}
w.activeUnitsLock.Lock()
defer w.activeUnitsLock.Unlock()
ident, err := w.generateUnitID(false)
ident, err := w.generateUnitID(false, workUnitID)
if err != nil {
return nil, err
}
Expand All @@ -270,7 +281,7 @@ func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string)
}

// AllocateRemoteUnit creates a new remote work unit and generates a local identifier for it.
func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, tlsClient, ttl string, signWork bool, params map[string]string) (WorkUnit, error) {
func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, workUnitID string, tlsClient, ttl string, signWork bool, params map[string]string) (WorkUnit, error) {
if tlsClient != "" {
_, err := w.nc.GetClientTLSConfig(tlsClient, "testhost", netceptor.ExpectedHostnameTypeReceptor)
if err != nil {
Expand All @@ -288,7 +299,7 @@ func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, tlsClient, t
if hasSecrets && tlsClient == "" {
return nil, fmt.Errorf("cannot send secrets over a non-TLS connection")
}
rw, err := w.AllocateUnit("remote", params)
rw, err := w.AllocateUnit("remote", workUnitID, params)
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/workceptor/workceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func testSetup(t *testing.T) (*gomock.Controller, *mock_workceptor.MockNetceptor

ctx := context.Background()
mockNetceptor := mock_workceptor.NewMockNetceptorForWorkceptor(ctrl)
mockNetceptor.EXPECT().NodeID().Return("test")
mockNetceptor.EXPECT().NodeID().Return("test").AnyTimes()

logger := logger.NewReceptorLogger("")
mockNetceptor.EXPECT().GetLogger().AnyTimes().Return(logger)
Expand Down Expand Up @@ -48,7 +48,7 @@ func TestAllocateUnit(t *testing.T) {
return mockWorkUnit
}

mockNetceptor.EXPECT().NodeID().Return("test")
mockNetceptor.EXPECT().NodeID().Return("test").Times(4)
w, err := workceptor.New(ctx, mockNetceptor, "/tmp")
if err != nil {
t.Errorf("Error while creating Workceptor: %v", err)
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestAllocateUnit(t *testing.T) {
mockWorkUnit.EXPECT().Save().Return(tc.saveError).Times(1)
}

_, err := w.AllocateUnit(tc.workType, map[string]string{"param": "value"})
_, err := w.AllocateUnit(tc.workType, "", map[string]string{"param": "value"})
checkError(err, tc.expectedError, t)
})
}
Expand Down Expand Up @@ -195,7 +195,7 @@ func TestRegisterWorker(t *testing.T) {
hasError: false,
expectedCalls: func() {
mockNetceptor.EXPECT().AddWorkCommand(gomock.Any(), gomock.Any())
w.AllocateUnit("remote", map[string]string{})
w.AllocateUnit("remote", "", map[string]string{})
},
},
}
Expand Down Expand Up @@ -350,7 +350,7 @@ func TestAllocateRemoteUnit(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.expectedCalls()
_, err := w.AllocateRemoteUnit("", "", tc.tlsClient, tc.ttl, tc.signWork, tc.params)
_, err := w.AllocateRemoteUnit("", "", "", tc.tlsClient, tc.ttl, tc.signWork, tc.params)

if tc.errorMsg != "" && tc.errorMsg != err.Error() && err != nil {
t.Errorf("expected: %s, received: %s", tc.errorMsg, err)
Expand Down

0 comments on commit 25ff7aa

Please sign in to comment.