-
Notifications
You must be signed in to change notification settings - Fork 0
/
request.go
123 lines (105 loc) · 3.5 KB
/
request.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package main
import (
"sync"
"strings"
"path/filepath"
"net/http"
"time"
"os"
"fmt"
"errors"
"io/fs"
"syscall"
)
// chooseLockPool maps a path to a pool index by summing the code points of
// the path and reducing the total modulo num_pools. The same path always
// yields the same index for a given num_pools.
func chooseLockPool(path string, num_pools int) int {
	var total int
	for _, codepoint := range path {
		total += int(codepoint)
	}
	index := total % num_pools
	return index
}
// activeRequestRegistry tracks the requests that are currently being
// processed, to prevent the same request being processed multiple times at
// the same time. We use a multi-pool approach to improve parallelism across
// requests: each path is assigned to exactly one pool (see chooseLockPool),
// and each pool has its own lock and its own set of active paths, so paths
// in different pools do not contend with each other.
type activeRequestRegistry struct {
// NumPools is the number of pools; it is the modulus used when choosing the
// pool index for a path.
NumPools int
// Locks holds one mutex per pool. Locks[i] is held while Active[i] is read
// or modified.
Locks []sync.Mutex
// Active holds one set of registered paths per pool. An entry may be nil
// until the first path is added to that pool.
Active []map[string]bool
}
// newActiveRequestRegistry creates a registry with num_pools pools. Each pool
// gets its own mutex; the per-pool maps are left nil here and are created on
// demand when a path is first added to the pool.
func newActiveRequestRegistry(num_pools int) *activeRequestRegistry {
	registry := new(activeRequestRegistry)
	registry.NumPools = num_pools
	registry.Locks = make([]sync.Mutex, num_pools)
	registry.Active = make([]map[string]bool, num_pools)
	return registry
}
// prefillActiveRequestRegistry registers every entry currently present in the
// staging directory with the registry, so that a user can't replay requests
// after a restart of the service.
//
// Each entry is registered under its bare file name and is scheduled for
// removal from the registry once 'expiry' has elapsed. The function itself
// does not wait for those removals; they fire in the background.
//
// An error is returned only if the staging directory cannot be listed; the
// underlying cause is wrapped so callers can inspect it with errors.Is/As.
func prefillActiveRequestRegistry(a *activeRequestRegistry, staging string, expiry time.Duration) error {
	entries, err := os.ReadDir(staging)
	if err != nil {
		return fmt.Errorf("failed to list existing request files in '%s'; %w", staging, err)
	}

	// This is only necessary until the expiry time is exceeded, after which we
	// can evict those entries. Technically we only need to do this for files
	// that weren't already expired, but this doesn't hurt.
	for _, e := range entries {
		// Declared inside the loop body so each timer callback captures its
		// own copy of the name, regardless of the Go version's loop semantics.
		name := e.Name()
		a.Add(name)

		// A timer is used instead of a sleeping goroutine per entry, so a
		// large staging directory doesn't park one goroutine per file.
		time.AfterFunc(expiry, func() {
			a.Remove(name)
		})
	}
	return nil
}
// Add registers a path as active. It returns true if the path was newly
// registered, or false if the path was already present in the registry (in
// which case nothing is changed).
func (a *activeRequestRegistry) Add(path string) bool {
	pool := chooseLockPool(path, a.NumPools)
	lock := &a.Locks[pool]
	lock.Lock()
	defer lock.Unlock()

	// Looking up a key in a nil map is safe and simply reports absence, so the
	// presence check can run before the pool's map has been created.
	if _, present := a.Active[pool][path]; present {
		return false
	}

	// The pool's map is created lazily on first use.
	if a.Active[pool] == nil {
		a.Active[pool] = map[string]bool{}
	}
	a.Active[pool][path] = true
	return true
}
// Remove drops a path from the registry. Removing a path that was never
// added is a no-op.
func (a *activeRequestRegistry) Remove(path string) {
	pool := chooseLockPool(path, a.NumPools)
	lock := &a.Locks[pool]
	lock.Lock()
	defer lock.Unlock()

	// delete tolerates both a missing key and a nil map.
	delete(a.Active[pool], path)
}
// checkRequestFile validates a request file name and returns the full path
// to that file inside the staging directory.
//
// The supplied path must start with "request-" and must be a bare file name
// with no directory components. The file it names inside 'staging' must
// exist, must not be a directory or a symbolic link, must not have more than
// one hard link, and must have been modified less than 'expiry' ago.
//
// On success the joined path (staging + name) is returned. Validation
// failures are returned as errors built by newHttpError with
// http.StatusBadRequest. A plain error is returned if the platform-specific
// stat information is not available as a *syscall.Stat_t.
func checkRequestFile(path, staging string, expiry time.Duration) (string, error) {
	if !strings.HasPrefix(path, "request-") {
		return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\""))
	}

	// Reject anything containing directory components, so the joined path
	// cannot refer to a location outside of the staging directory.
	if path != filepath.Base(path) {
		return "", newHttpError(http.StatusBadRequest, errors.New("path should be the name of a file in the staging directory"))
	}

	// Lstat is used (rather than Stat) so that a symbolic link is reported as
	// such instead of being followed to its target.
	reqpath := filepath.Join(staging, path)
	info, err := os.Lstat(reqpath)
	if err != nil {
		return "", newHttpError(http.StatusBadRequest, fmt.Errorf("failed to access path; %v", err))
	}

	if info.IsDir() {
		return "", newHttpError(http.StatusBadRequest, errors.New("path is a directory"))
	}
	if info.Mode()&fs.ModeSymlink != 0 {
		return "", newHttpError(http.StatusBadRequest, errors.New("path is a symbolic link"))
	}

	// NOTE(review): the hard link check presumably guards against a request
	// file that aliases some other file on the same filesystem - confirm.
	s, ok := info.Sys().(*syscall.Stat_t)
	if !ok {
		// There is no underlying error to wrap here: Lstat succeeded, so 'err'
		// is nil and only the type assertion failed.
		return "", errors.New("failed to convert to a syscall.Stat_t")
	}
	// Nlink's integer width varies by platform; widening to uint64 avoids
	// truncating a 64-bit link count.
	if uint64(s.Nlink) > 1 {
		return "", newHttpError(http.StatusBadRequest, errors.New("path seems to have multiple hard links"))
	}

	if time.Since(info.ModTime()) >= expiry {
		return "", newHttpError(http.StatusBadRequest, errors.New("request file is expired"))
	}

	return reqpath, nil
}