-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
service.go
1172 lines (1023 loc) · 37.9 KB
/
service.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
package waf
import (
"bytes"
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"io/fs"
"log"
"mime"
"net/http"
"net/http/httputil"
"net/url"
"path/filepath"
"reflect"
"slices"
"strconv"
"strings"
"time"
mapset "github.com/deckarep/golang-set/v2"
"github.com/hashicorp/go-cleanhttp"
"github.com/justinas/alice"
"github.com/rs/cors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"gitlab.com/tozd/go/errors"
"gitlab.com/tozd/go/x"
z "gitlab.com/tozd/go/zerolog"
)
// CORSOptions is a subset of cors.Options.
//
// See description of fields in cors.Options.
//
// See: https://github.com/rs/cors/pull/164
type CORSOptions struct {
	AllowedOrigins       []string `json:"allowedOrigins,omitempty"`
	AllowedMethods       []string `json:"allowedMethods,omitempty"`
	AllowedHeaders       []string `json:"allowedHeaders,omitempty"`
	ExposedHeaders       []string `json:"exposedHeaders,omitempty"`
	MaxAge               int      `json:"maxAge,omitempty"`
	AllowCredentials     bool     `json:"allowCredentials,omitempty"`
	AllowPrivateNetwork  bool     `json:"allowPrivateNetwork,omitempty"`
	OptionsSuccessStatus int      `json:"optionsSuccessStatus,omitempty"`
}

// GetAllowedMethods returns the effective list of allowed CORS methods.
//
// When AllowedMethods is empty, only GET and HEAD are allowed (unlike the cors
// package default, which also includes POST). Otherwise the configured methods
// are returned upper-cased, in their original order, and HEAD is appended when
// GET is present but HEAD is not.
func (c *CORSOptions) GetAllowedMethods() []string {
	if len(c.AllowedMethods) == 0 {
		return []string{http.MethodGet, http.MethodHead}
	}

	// One extra slot for HEAD which we might have to append.
	methods := make([]string, 0, len(c.AllowedMethods)+1)
	var sawGet, sawHead bool
	for _, configured := range c.AllowedMethods {
		upper := strings.ToUpper(configured)
		methods = append(methods, upper)
		switch upper {
		case http.MethodGet:
			sawGet = true
		case http.MethodHead:
			sawHead = true
		}
	}

	// GET implies HEAD.
	if sawGet && !sawHead {
		methods = append(methods, http.MethodHead)
	}
	return methods
}
// RouteOptions holds options for one kind of handler (API or non-API) of a route.
type RouteOptions struct {
// CORS, when not nil, enables CORS on the handler(s) using these options.
CORS *CORSOptions `json:"cors,omitempty"`
}
// Route is a high-level route definition which is used by a service
// to register handlers with the router. It can also be used by Vue Router
// to register routes there.
//
// At least one of API and Get has to be set for the route to be accepted
// when routes are configured.
type Route struct {
// Name of the route. It should be unique. Handler methods are looked up by this name.
Name string `json:"name"`
// Path for the route. It can contain parameters.
Path string `json:"path"`
// API, when not nil, marks that this route supports API handlers.
// API paths are automatically prefixed with /api.
API *RouteOptions `json:"api,omitempty"`
// Get, when not nil, marks that this route has a non-API handler.
Get *RouteOptions `json:"get,omitempty"`
}
// staticFile is one stored representation of a static file.
type staticFile struct {
// Data is the file content, as produced for one particular compression.
Data []byte
// Etag is computed from Data.
Etag string
// MediaType is the media type of the (uncompressed) file.
MediaType string
}
// Site describes the site at a domain.
//
// A service can have multiple sites which share static files and handlers,
// but have different configuration and rendered HTML files. The core of
// such configuration is the site's domain, but you can provide your own
// site struct and embed Site to add additional configuration.
// Your site struct is then used when rendering HTML files and
// as site context to the frontend at SiteContextPath URL path.
//
// Certificate and key file paths are not exposed in site context JSON.
type Site struct {
Domain string `json:"domain" required:"" yaml:"domain"`
// Certificate file path for the site. It should be valid for the domain.
// Used when Let's Encrypt is not configured.
CertFile string `help:"Certificate for TLS, when not using Let's Encrypt." json:"-" name:"cert" placeholder:"PATH" type:"existingfile" yaml:"cert,omitempty"`
// Key file path. Used when Let's Encrypt is not configured.
KeyFile string `help:"Certificate's private key, when not using Let's Encrypt." json:"-" name:"key" placeholder:"PATH" type:"existingfile" yaml:"key,omitempty"`
// Maps compression to path to data/etag/media type.
// They are per site because they can include rendered per-site data.
// File contents are deduplicated between sites if they are the same.
staticFiles map[string]map[string]staticFile
}
// Validate checks that the certificate and private key file paths are either
// both set or both empty. It returns an error naming the site's domain when
// only one of them is provided.
func (s *Site) Validate() error {
	hasCert := s.CertFile != ""
	hasKey := s.KeyFile != ""

	switch {
	case hasKey && !hasCert:
		return errors.Errorf(`missing file certificate for provided private key for site "%s"`, s.Domain)
	case hasCert && !hasKey:
		return errors.Errorf(`missing file certificate's matching private key for site "%s"`, s.Domain)
	}

	return nil
}
// GetSite returns the receiver itself. It exists so that code which is given
// your own site struct can access the Site struct inside it. If you embed Site
// inside your site struct then this method is promoted to your site struct and
// does the right thing automatically.
func (s *Site) GetSite() *Site {
return s
}
// initializeStaticFiles replaces the site's static files with a fresh map
// which has one empty per-path map for every supported compression.
func (s *Site) initializeStaticFiles() {
	files := make(map[string]map[string]staticFile, len(allCompressions))
	for i := range allCompressions {
		files[allCompressions[i]] = map[string]staticFile{}
	}
	s.staticFiles = files
}
// addStaticFile stores data as a static file of the site at path, once for every
// compression for which storing it is worthwhile.
//
// The path has to start with "/" and must not be already present among the
// uncompressed (identity) static files, otherwise an error is returned.
// Data which is not larger than minCompressionSize is stored only uncompressed.
// A compressed variant is skipped when the ratio between its size and the
// uncompressed size is at least minCompressionRatio.
//
// Static files have to be initialized before calling this method.
func (s *Site) addStaticFile(path, mediaType string, data []byte) errors.E {
	if !strings.HasPrefix(path, "/") {
		errE := errors.New(`path does not start with "/"`)
		errors.Details(errE)["path"] = path
		return errE
	}

	if _, exists := s.staticFiles[compressionIdentity][path]; exists {
		errE := errors.New(`static file for path already exists`)
		errors.Details(errE)["path"] = path
		return errE
	}

	encodings := allCompressions
	if len(data) <= minCompressionSize {
		// Too small to be worth compressing.
		encodings = []string{compressionIdentity}
	}

	for _, encoding := range encodings {
		encoded, errE := compress(encoding, data)
		if errE != nil {
			errors.Details(errE)["path"] = path
			return errE
		}

		if encoding != compressionIdentity {
			// len(data) cannot be 0 here because 0 <= minCompressionSize
			// and then only compressionIdentity is tried.
			ratio := float64(len(encoded)) / float64(len(data))
			if ratio >= minCompressionRatio {
				// No need to compress noncompressible files.
				continue
			}
		}

		s.staticFiles[encoding][path] = staticFile{
			Data:      encoded,
			Etag:      computeEtag(encoded),
			MediaType: mediaType,
		}
	}

	return nil
}
// hasSite is the constraint for site types: anything from which
// the underlying Site struct can be obtained.
type hasSite interface {
// GetSite returns the Site struct of the site.
GetSite() *Site
}
// newSiteT allocates a new value for SiteT and returns it together with a pointer
// to its internal Site.
//
// SiteT is expected to be a pointer type: the type it points to is determined
// using reflect and a new zero value of it is allocated. We use reflect to make
// this work with current Go type system limitations. Because this is not used
// in critical paths, use of reflect seems reasonable.
//
// See: https://go.dev/play/p/j0GRRI96WMM
// See: https://github.com/golang/go/issues/63708
func newSiteT[SiteT hasSite]() (SiteT, *Site) { //nolint:ireturn
	// A typed nil pointer to SiteT is enough to obtain the type SiteT points to.
	var typedNil *SiteT
	pointee := reflect.TypeOf(typedNil).Elem().Elem()
	siteT := reflect.New(pointee).Interface().(SiteT) //nolint:forcetypeassert,errcheck
	return siteT, siteT.GetSite()
}
// newCORS builds a CORS handler from options. It returns nil when options
// is nil, i.e., when CORS is not enabled.
func newCORS(options *CORSOptions) *cors.Cors {
	if options == nil {
		return nil
	}

	corsOptions := cors.Options{ //nolint:exhaustruct
		AllowedOrigins:       options.AllowedOrigins,
		AllowedMethods:       options.GetAllowedMethods(),
		AllowedHeaders:       options.AllowedHeaders,
		ExposedHeaders:       options.ExposedHeaders,
		MaxAge:               options.MaxAge,
		AllowCredentials:     options.AllowCredentials,
		AllowPrivateNetwork:  options.AllowPrivateNetwork,
		OptionsSuccessStatus: options.OptionsSuccessStatus,
		// We always passthrough and call w.WriteHeader ourselves,
		// unless there is API OPTIONS handler which we then call instead.
		OptionsPassthrough: true,
	}
	return cors.New(corsOptions)
}
// wrapGetCORS wraps non-API handler h with CORS handling configured by options.
//
// It returns two handlers: the first one is h behind the CORS handler and is meant
// for non-OPTIONS requests, the second one is meant for OPTIONS requests and only
// writes the status code (options.OptionsSuccessStatus, or 204 No Content if that
// is zero) after the CORS handler has processed the request.
//
// options must not be nil.
func wrapGetCORS(options *CORSOptions, h func(http.ResponseWriter, *http.Request, Params)) (
	func(http.ResponseWriter, *http.Request, Params),
	func(http.ResponseWriter, *http.Request, Params),
) {
	c := newCORS(options)

	status := http.StatusNoContent
	if options.OptionsSuccessStatus != 0 {
		status = options.OptionsSuccessStatus
	}

	// We do nothing after OPTIONS request has been handled,
	// even if it was not a CORS OPTIONS request.
	writeStatus := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(status)
	})

	// Non-OPTIONS request.
	requestH := func(w http.ResponseWriter, r *http.Request, params Params) {
		next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h(w, r, params)
		})
		c.Handler(next).ServeHTTP(w, r)
	}

	// OPTIONS request.
	optionsH := func(w http.ResponseWriter, r *http.Request, _ Params) {
		c.Handler(writeStatus).ServeHTTP(w, r)
	}

	return requestH, optionsH
}
// wrapCORS returns a handler which runs h behind CORS handler c,
// passing route parameters through to h.
func wrapCORS(c *cors.Cors, h func(http.ResponseWriter, *http.Request, Params)) func(http.ResponseWriter, *http.Request, Params) {
	return func(w http.ResponseWriter, r *http.Request, params Params) {
		// The inner handler is made per request because it has to carry params.
		next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h(w, r, params)
		})
		c.Handler(next).ServeHTTP(w, r)
	}
}
// methodsSubset checks that every method allowed by CORS options has a handler,
// i.e., that allowed methods are a subset of methodsWithHandlers.
//
// If not, it returns an error with the sorted list of methods without
// handlers in its "extra" detail.
func methodsSubset(options *CORSOptions, methodsWithHandlers []string) errors.E {
	allowed := mapset.NewThreadUnsafeSet[string]()
	allowed.Append(options.GetAllowedMethods()...)

	handled := mapset.NewThreadUnsafeSet[string]()
	handled.Append(methodsWithHandlers...)

	withoutHandlers := allowed.Difference(handled)
	if withoutHandlers.Cardinality() <= 0 {
		return nil
	}

	// Sorting makes the error details deterministic.
	extra := withoutHandlers.ToSlice()
	slices.Sort(extra)

	errE := errors.New("CORS allowed methods contain methods without handlers")
	errors.Details(errE)["extra"] = extra
	return errE
}
// Service defines the application logic for your service.
//
// You should embed the Service struct inside your service struct on which
// you define handlers as methods with [Handler] signature. Handlers together
// with StaticFiles, Routes and Sites define how the service handles HTTP
// requests.
type Service[SiteT hasSite] struct {
// General logger for the service.
Logger zerolog.Logger
// Canonical log line logger for the service which logs one log entry per
// request. It is automatically populated with data about the request.
CanonicalLogger zerolog.Logger
// WithContext is a function which adds to the context a logger.
// It is then accessible using zerolog.Ctx(ctx).
// The first function is called when the request is handled and allows
// any cleanup necessary. The second function is called on panic.
// If WithContext is not set, Logger is used instead.
WithContext func(context.Context) (context.Context, func(), func()) `exhaustruct:"optional"`
// StaticFiles to be served by the service. All paths are anchored at / when served.
// HTML files (those with ".html" extension) are rendered using html/template
// with site struct as data. Other files are served as-is.
StaticFiles fs.ReadFileFS
// Routes to be handled by the service and mapped to its Handler methods.
Routes []Route
// Sites configured for the service. Key in the map must match site's domain.
// This should generally be set to sites returned from Server.Init method.
Sites map[string]SiteT
// Middleware is a chain of additional middleware to append before the router.
Middleware []func(http.Handler) http.Handler `exhaustruct:"optional"`
// SiteContextPath is the path at which site context (JSON of site struct)
// should be added to static files. If empty, site context is not added.
SiteContextPath string `exhaustruct:"optional"`
// MetadataHeaderPrefix is an optional prefix to the Metadata response header.
MetadataHeaderPrefix string `exhaustruct:"optional"`
// ProxyStaticTo is a base URL to proxy to during development, if set.
// This should generally be set to result of Server.ProxyToInDevelopment method.
// If set, StaticFiles are not served by the service so that they can be proxied instead.
ProxyStaticTo string `exhaustruct:"optional"`
// IsImmutableFile should return true if the static file is immutable and
// should have such caching headers. Static files are those which do not change
// during a runtime of the program. Immutable files are those which are never changed.
IsImmutableFile func(path string) bool `exhaustruct:"optional"`
// SkipServingFile should return true if the static file should not be automatically
// registered with the router to be served. It can still be served using ServeStaticFile.
SkipServingFile func(path string) bool `exhaustruct:"optional"`
// router is set by RouteWith. A non-nil value marks that RouteWith has already been called.
router *Router `exhaustruct:"optional"`
// reverseProxy is set by makeReverseProxy, which RouteWith calls when ProxyStaticTo is set.
reverseProxy *httputil.ReverseProxy `exhaustruct:"optional"`
}
// RouteWith registers static files and handlers with the router based on Routes and service [Handler]
// methods and returns a [http.Handler] to be used with the [Server].
//
// You should generally pass your service struct with embedded Service struct as service
// parameter so that handler methods can be detected. Non-API handler methods should
// have the same name as the route. API handler methods should have the name
// matching the route name with HTTP method name as suffix (e.g., "CommentPost" for
// route with name "Comment" and POST HTTP method).
//
// RouteWith returns an error if it has already been called on this Service.
// The returned handler is the router wrapped in a chain of middleware; the order
// in which middleware is appended below matters.
func (s *Service[SiteT]) RouteWith(service interface{}, router *Router) (http.Handler, errors.E) {
if s.router != nil {
return nil, errors.New("RouteWith called more than once")
}
s.router = router
errE := s.configureRoutes(service)
if errE != nil {
return nil, errE
}
if s.ProxyStaticTo != "" {
// Development mode: only site context is registered as a static file,
// everything else which the router does not match is proxied.
s.Logger.Debug().Str("proxy", s.ProxyStaticTo).Msg("proxying static files")
errE := s.renderAndCompressSiteContext()
if errE != nil {
return nil, errE
}
errE = s.serveStaticFiles()
if errE != nil {
return nil, errE
}
errE = s.makeReverseProxy()
if errE != nil {
return nil, errE
}
// Both not found and method not allowed requests are handed to the proxy.
// This overrides any NotFound and MethodNotAllowed handlers set on the router.
p := logHandlerFuncName("Proxy", s.Proxy)
s.router.NotFound = p
s.router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request, _ Params, _ []string) {
p(w, req)
}
} else {
// Static files are rendered first and site context is added after them.
errE := s.renderAndCompressStaticFiles()
if errE != nil {
return nil, errE
}
errE = s.renderAndCompressSiteContext()
if errE != nil {
return nil, errE
}
errE = s.serveStaticFiles()
if errE != nil {
return nil, errE
}
// Handlers already set on the router are kept, but wrapped so that
// the canonical log line message is set.
if s.router.NotFound == nil {
s.router.NotFound = logHandlerFuncName("NotFound", s.NotFound)
} else {
s.router.NotFound = logHandlerFuncName("NotFound", s.router.NotFound)
}
if s.router.MethodNotAllowed == nil {
s.router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request, _ Params, allow []string) {
*canonicalLoggerMessage(req.Context()) = "MethodNotAllowed" //nolint:goconst
s.MethodNotAllowed(w, req, allow)
}
} else {
// Capture the existing handler before we replace it with the wrapper.
m := s.router.MethodNotAllowed
s.router.MethodNotAllowed = func(w http.ResponseWriter, req *http.Request, params Params, allow []string) {
*canonicalLoggerMessage(req.Context()) = "MethodNotAllowed"
m(w, req, params, allow)
}
}
}
if s.router.Panic == nil {
s.router.Panic = s.handlePanic
}
c := alice.New()
// We first create a canonical log line logger as context logger.
c = c.Append(hlog.NewHandler(s.CanonicalLogger))
// Then we set the canonical log line logger under its own context key as well.
c = c.Append(setCanonicalLogger)
// It has to be before accessHandler so that it can access the metrics context.
c = c.Append(metricsMiddleware)
// Is logger enabled at all (not zerolog.Nop or zero zerolog struct)?
// See: https://github.com/rs/zerolog/pull/617
if l := s.CanonicalLogger.Sample(nil); l.Log().Enabled() { //nolint:zerologlint
c = c.Append(accessHandler(func(req *http.Request, code int, responseBody, requestBody int64, duration time.Duration) {
ctx := req.Context()
// Log level is chosen based on the response status code:
// info below 400, warn for 4xx, and error for 5xx.
level := zerolog.InfoLevel
if code >= http.StatusBadRequest {
level = zerolog.WarnLevel
}
if code >= http.StatusInternalServerError {
level = zerolog.ErrorLevel
}
metrics := MustGetMetrics(ctx)
metrics.Duration(MetricTotal).Duration = duration
l := zerolog.Ctx(ctx).WithLevel(level) //nolint:zerologlint
if code != 0 {
l = l.Int("code", code)
}
l = l.Int64("responseBody", responseBody).
Int64("requestBody", requestBody).
Object("metrics", metrics)
// The message is whatever a handler stored as the canonical log line message, if anything.
message := canonicalLoggerMessage(ctx)
if *message != "" {
l.Msg(*message)
} else {
l.Send()
}
}))
// The following middleware adds fields to the canonical log line.
c = c.Append(logMetadata(s.MetadataHeaderPrefix))
c = c.Append(websocketHandler("ws"))
c = c.Append(hlog.MethodHandler("method"))
c = c.Append(urlHandler("path"))
c = c.Append(hlog.RemoteIPHandler("client"))
c = c.Append(hlog.UserAgentHandler("agent"))
c = c.Append(hlog.RefererHandler("referer"))
c = c.Append(connectionIDHandler("connection"))
c = c.Append(requestIDHandler("request", "Request-Id"))
c = c.Append(hlog.HTTPVersionHandler("proto"))
c = c.Append(hlog.HostHandler("host", true))
c = c.Append(hlog.EtagHandler("etag"))
c = c.Append(hlog.ResponseHeaderHandler("encoding", "Content-Encoding"))
} else {
// Canonical logger is disabled: we still record the total duration metric
// and install the request ID handler (with an empty field name).
c = c.Append(accessHandler(func(req *http.Request, _ int, _, _ int64, duration time.Duration) {
ctx := req.Context()
metrics := MustGetMetrics(ctx)
metrics.Duration(MetricTotal).Duration = duration
}))
c = c.Append(requestIDHandler("", "Request-Id"))
}
c = c.Append(addNosniffHeader)
// parseForm should be towards the end because it can fail or redirect
// and we want other fields to be logged. It also logs query string and
// redirects to canonical query strings.
c = c.Append(s.parseForm("query", "rawQuery"))
// validatePath should be towards the end because it can fail or redirect
// and we want other fields to be logged. It redirects to canonical path.
c = c.Append(s.validatePath)
// validateSite should be towards the end because it can fail and we want
// other fields to be logged.
c = c.Append(s.validateSite)
// We replace the canonical log line logger with a new context logger, but with associated request ID.
// The canonical log line logger is still available under its own context key.
if s.WithContext != nil {
c = c.Append(z.NewHandler(s.WithContext))
} else {
c = c.Append(hlog.NewHandler(s.Logger))
}
c = c.Append(requestIDHandler("request", ""))
// User-provided middleware runs last, just before the router.
for _, m := range s.Middleware {
c = c.Append(m)
}
return c.Then(s.router), nil
}
// configureRoutes registers handlers for all Routes with the router.
//
// Handlers are looked up by name among methods of service using reflection.
// For a route with Get set, the method named after the route is registered for GET.
// For a route with API set, methods named after the route with the HTTP method as
// suffix (e.g., "CommentPost") are registered as API handlers; the GET API handler
// is registered for HEAD as well. When CORS is configured, handlers are wrapped with
// CORS handling, an OPTIONS handler is registered if the service does not provide one,
// and it is checked that all CORS allowed methods have handlers.
//
// Returned errors carry the route name and path (and the handler name where known)
// in their details.
func (s *Service[SiteT]) configureRoutes(service interface{}) errors.E {
v := reflect.ValueOf(service)
for _, route := range s.Routes {
if route.Get == nil && route.API == nil {
errE := errors.New(`at least one of "get" and "api" has to be set`)
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
if route.Get != nil {
// Non-API handler has the same name as the route.
handlerName := route.Name
m := v.MethodByName(handlerName)
if !m.IsValid() {
errE := errors.New("handler not found")
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
s.Logger.Debug().Str("handler", handlerName).Str("route", route.Name).Str("path", route.Path).Msg("route registration: handler found")
// We cannot use Handler here because it is a named type.
h, ok := m.Interface().(func(http.ResponseWriter, *http.Request, Params))
if !ok {
errE := errors.New("invalid handler type")
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
errors.Details(errE)["type"] = fmt.Sprintf("%T", m.Interface())
return errE
}
if route.Get.CORS != nil {
// A non-API handler serves only GET and HEAD, so CORS must not allow anything else.
errE := methodsSubset(route.Get.CORS, []string{http.MethodGet, http.MethodHead})
if errE != nil {
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
var optionsH func(http.ResponseWriter, *http.Request, Params)
h, optionsH = wrapGetCORS(route.Get.CORS, h)
optionsH = logHandlerName(handlerName, optionsH)
errE = s.router.Handle(route.Name, http.MethodOptions, route.Path, false, optionsH)
if errE != nil {
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
}
h = logHandlerName(handlerName, h)
// HEAD method is already handled by the router for non-API requests.
errE := s.router.Handle(route.Name, http.MethodGet, route.Path, false, h)
if errE != nil {
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
}
if route.API != nil { //nolint:nestif
// c is nil when CORS is not enabled for API handlers of this route.
c := newCORS(route.API.CORS)
foundAnyAPIHandler := false
foundOptionsHandler := false
foundMethods := []string{}
// HEAD is not in the list because it is handled by the GET handler.
for _, method := range []string{
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch,
http.MethodDelete, http.MethodConnect, http.MethodOptions, http.MethodTrace,
} {
// Handler name is the route name with the title-cased HTTP method as suffix, e.g., "CommentPost".
handlerName := fmt.Sprintf("%s%s", route.Name, strings.Title(strings.ToLower(method))) //nolint:staticcheck
m := v.MethodByName(handlerName)
if !m.IsValid() {
s.Logger.Debug().Str("handler", handlerName).Str("route", route.Name).Str("path", route.Path).Msg("route registration: API handler not found")
continue
}
s.Logger.Debug().Str("handler", handlerName).Str("route", route.Name).Str("path", route.Path).Msg("route registration: API handler found")
foundAnyAPIHandler = true
// We cannot use Handler here because it is a named type.
h, ok := m.Interface().(func(http.ResponseWriter, *http.Request, Params))
if !ok {
errE := errors.New("invalid API handler type")
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
errors.Details(errE)["type"] = fmt.Sprintf("%T", m.Interface())
return errE
}
if c != nil {
h = wrapCORS(c, h)
// Remember that the service provides its own OPTIONS handler
// so that we do not register the default one below.
if method == http.MethodOptions {
foundOptionsHandler = true
}
}
h = logHandlerName(handlerName, h)
errE := s.router.Handle(route.Name, method, route.Path, true, h)
if errE != nil {
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
foundMethods = append(foundMethods, method)
if method == http.MethodGet {
// The GET API handler is registered for HEAD as well.
errE := s.router.Handle(route.Name, http.MethodHead, route.Path, true, h)
if errE != nil {
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
foundMethods = append(foundMethods, http.MethodHead)
}
}
if !foundAnyAPIHandler {
errE := errors.New("no API handler found")
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
if c != nil {
if !foundOptionsHandler {
// The service does not provide an OPTIONS API handler, so we register one which
// only writes the status code after the CORS handler has processed the request.
handlerName := fmt.Sprintf("%s%s", route.Name, strings.Title(strings.ToLower(http.MethodOptions))) //nolint:staticcheck
optionsSuccessStatus := route.API.CORS.OptionsSuccessStatus
if optionsSuccessStatus == 0 {
optionsSuccessStatus = http.StatusNoContent
}
h := func(w http.ResponseWriter, r *http.Request, _ Params) {
c.Handler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
// We do nothing after OPTIONS request has been handled,
// even if it was not a CORS OPTIONS request.
w.WriteHeader(optionsSuccessStatus)
})).ServeHTTP(w, r)
}
h = logHandlerName(handlerName, h)
errE := s.router.Handle(route.Name, http.MethodOptions, route.Path, true, h)
if errE != nil {
errors.Details(errE)["handler"] = handlerName
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
}
// All CORS allowed methods must have handlers. The OPTIONS handler
// registered just above is not added to foundMethods.
errE := methodsSubset(route.API.CORS, foundMethods)
if errE != nil {
errors.Details(errE)["route"] = route.Name
errors.Details(errE)["path"] = route.Path
return errE
}
}
}
}
return nil
}
// renderAndCompressStaticFiles initializes static files of all sites and populates
// them with all files found in StaticFiles.
//
// Files with ".html" suffix are rendered as templates for each site separately
// and then added to that site. Other files are compressed once and the result is
// shared between all sites. Paths of static files are paths inside StaticFiles
// prefixed with "/".
//
// It returns an error if static files of any site have already been initialized.
func (s *Service[SiteT]) renderAndCompressStaticFiles() errors.E {
	for _, siteT := range s.Sites {
		site := siteT.GetSite()
		if site.staticFiles != nil {
			return errors.New("renderAndCompressStaticFiles called more than once")
		}
		site.initializeStaticFiles()
	}

	if s.StaticFiles == nil {
		return nil
	}

	visit := func(path string, entry fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return errors.WithStack(walkErr)
		}
		if entry.IsDir() {
			return nil
		}

		servedPath := "/" + path

		contents, readErr := s.StaticFiles.ReadFile(path)
		if readErr != nil {
			errE := errors.WithStack(readErr)
			errors.Details(errE)["path"] = servedPath
			return errE
		}

		mediaType := mime.TypeByExtension(filepath.Ext(path))
		if mediaType == "" {
			s.Logger.Debug().Str("path", servedPath).Msg("unable to determine content type for static file")
			mediaType = "application/octet-stream"
		}

		if strings.HasSuffix(servedPath, ".html") {
			// Each site might render HTML files differently.
			for _, siteT := range s.Sites {
				site := siteT.GetSite()
				rendered, errE := s.render(servedPath, contents, siteT)
				if errE != nil {
					return errE
				}
				if errE := site.addStaticFile(servedPath, mediaType, rendered); errE != nil {
					return errE
				}
			}
		} else {
			// We do not use Site.addStaticFile here so that we can reuse and deduplicate
			// compressed static files across all sites by inverting loops (here we first
			// iterate over compressions and then over sites).
			encodings := allCompressions
			if len(contents) <= minCompressionSize {
				encodings = []string{compressionIdentity}
			}
			for _, encoding := range encodings {
				encoded, errE := compress(encoding, contents)
				if errE != nil {
					errors.Details(errE)["path"] = servedPath
					return errE
				}
				if encoding != compressionIdentity {
					// len(contents) cannot be 0 here because 0 <= minCompressionSize
					// and then only compressionIdentity is tried.
					ratio := float64(len(encoded)) / float64(len(contents))
					if ratio >= minCompressionRatio {
						// No need to compress noncompressible files.
						continue
					}
				}
				shared := staticFile{
					Data:      encoded,
					Etag:      computeEtag(encoded),
					MediaType: mediaType,
				}
				for _, siteT := range s.Sites {
					siteT.GetSite().staticFiles[encoding][servedPath] = shared
				}
			}
		}

		s.Logger.Debug().Str("path", servedPath).Msg("added file to static files")
		return nil
	}

	return errors.WithStack(fs.WalkDir(s.StaticFiles, ".", visit))
}
// renderAndCompressSiteContext adds site context (site struct encoded as JSON)
// to static files of every site at SiteContextPath. It does nothing if
// SiteContextPath is empty.
func (s *Service[SiteT]) renderAndCompressSiteContext() errors.E {
	contextPath := s.SiteContextPath
	if contextPath == "" {
		return nil
	}

	for _, siteT := range s.Sites {
		site := siteT.GetSite()

		// In development, this method could be called first and static files are not yet
		// initialized (as requests for other static files are proxied), while in production
		// static files have already been initialized and populated.
		if site.staticFiles == nil {
			site.initializeStaticFiles()
		}

		encoded, errE := x.MarshalWithoutEscapeHTML(siteT)
		if errE != nil {
			return errE
		}

		if errE := site.addStaticFile(contextPath, "application/json", encoded); errE != nil {
			return errE
		}
	}

	s.Logger.Debug().Str("path", contextPath).Msg("added file to static files")

	return nil
}
// makeReverseProxy creates the reverse proxy which forwards requests to ProxyStaticTo
// and stores it into s.reverseProxy.
//
// It returns an error if the reverse proxy has already been made or if
// ProxyStaticTo cannot be parsed as a URL (the URL is in error's "url" detail).
func (s *Service[SiteT]) makeReverseProxy() errors.E {
	if s.reverseProxy != nil {
		return errors.New("makeReverseProxy called more than once")
	}

	target, err := url.Parse(s.ProxyStaticTo)
	if err != nil {
		errE := errors.WithStack(err)
		errors.Details(errE)["url"] = s.ProxyStaticTo
		return errE
	}

	// We reuse the standard single-host director and extend it.
	singleHostDirector := httputil.NewSingleHostReverseProxy(target).Director
	director := func(req *http.Request) {
		singleHostDirector(req)

		// We pass request ID through.
		req.Header.Set("Request-Id", MustRequestID(req.Context()).String())

		// We potentially parse PostForm in parseForm middleware. In that case
		// the body is consumed and closed. We have to reconstruct it here.
		if postFormParsed(req) {
			encoded := req.PostForm.Encode()
			req.Body = io.NopCloser(strings.NewReader(encoded))
			// Our reconstruction might have a different length. When sending the request,
			// net/http frames the body based on the ContentLength field and ignores the
			// Content-Length header, so the field has to match the new body or the
			// transport fails with a body length mismatch error.
			req.ContentLength = int64(len(encoded))
			if req.Header.Get("Content-Length") != "" {
				// We keep the header consistent with the field.
				req.Header.Set("Content-Length", strconv.Itoa(len(encoded)))
			}
		}

		// TODO: Map origin and other headers.
	}

	// TODO: Map response cookies, other headers which include origin, and redirect locations.
	s.reverseProxy = &httputil.ReverseProxy{
		Rewrite:   nil,
		Director:  director,
		Transport: cleanhttp.DefaultPooledTransport(),
		// Negative interval: flush to the client immediately after each write.
		FlushInterval:  -1,
		ErrorLog:       log.New(s.Logger, "", 0),
		BufferPool:     nil,
		ModifyResponse: nil,
		ErrorHandler:   nil,
	}

	return nil
}
// serveStaticFiles registers GET handlers with the router for paths of static files.
//
// Paths for which SkipServingFile returns true are not registered. Paths for which
// IsImmutableFile returns true are served by the immutable file handler, all
// others by the static file handler.
func (s *Service[SiteT]) serveStaticFiles() errors.E {
	staticH := logHandlerName("StaticFile", toHandler(s.staticFile))
	immutableH := logHandlerName("ImmutableFile", toHandler(s.immutableFile))

	// We can use any site to obtain all static paths, so we take whichever comes
	// first. We can also use any compression, so we use compressionIdentity.
	var paths map[string]staticFile
	for _, siteT := range s.Sites {
		paths = siteT.GetSite().staticFiles[compressionIdentity]
		break
	}

	for path := range paths {
		if s.SkipServingFile != nil && s.SkipServingFile(path) {
			continue
		}

		name := "StaticFile:" + path
		var h Handler = staticH
		if s.IsImmutableFile != nil && s.IsImmutableFile(path) {
			name = "ImmutableFile:" + path
			h = immutableH
		}

		if errE := s.router.Handle(name, http.MethodGet, path, false, h); errE != nil {
			return errors.WithDetails(errE, "path", path)
		}
	}

	return nil
}
// render parses data as a html/template template named path and executes it
// with siteT as template data. It returns the rendered output.
//
// Errors from parsing and executing the template carry path in their details.
func (s *Service[SiteT]) render(path string, data []byte, siteT SiteT) ([]byte, errors.E) {
	tmpl, err := template.New(path).Parse(string(data))
	if err != nil {
		return nil, errors.WithDetails(err, "path", path)
	}

	rendered := new(bytes.Buffer)
	if err := tmpl.Execute(rendered, siteT); err != nil {
		return nil, errors.WithDetails(err, "path", path)
	}

	return rendered.Bytes(), nil
}
// AddMetadata adds header with metadata to the response.
//
// Metadata is encoded based on [RFC 8941]. Header name is "Metadata" with
// optional MetadataHeaderPrefix.
//
// It returns the encoded metadata, or nil if metadata is empty, in which
// case no header is added. Metadata is also recorded for logging when the
// request context carries the log metadata map.
//
// [RFC 8941]: https://www.rfc-editor.org/rfc/rfc8941
func (s *Service[SiteT]) AddMetadata(w http.ResponseWriter, req *http.Request, metadata map[string]interface{}) ([]byte, errors.E) {
	if len(metadata) == 0 {
		return nil, nil
	}

	var encoded bytes.Buffer
	if errE := encodeMetadata(metadata, &encoded); errE != nil {
		return nil, errE
	}

	w.Header().Add(s.MetadataHeaderPrefix+metadataHeader, encoded.String())

	// metadataContextKey might not exist if provided logger is disabled.
	if logged, ok := req.Context().Value(metadataContextKey).(map[string]interface{}); ok {
		for key, value := range metadata {
			// We overwrite any existing key. This is the same behavior RFC 8941 specifies
			// for duplicate keys in its dictionaries. The last one wins.
			logged[key] = value
		}
	}

	return encoded.Bytes(), nil
}
// PrepareJSON prepares the JSON response to the request. It populates
// response headers and encodes data as JSON.
// Optional metadata is added as the response header.
//
// Besides other types, data can be of type []byte and [json.RawMessage] in which
// case it is expected that it already contains a well-formed JSON and is returned
// as-is.
//
// If there is an error, PrepareJSON responds to the request and returns nil.
func (s *Service[SiteT]) PrepareJSON(w http.ResponseWriter, req *http.Request, data interface{}, metadata map[string]interface{}) []byte {
ctx := req.Context()
metrics := MustGetMetrics(ctx)
var encoded []byte
switch d := data.(type) {
case []byte:
encoded = d
case json.RawMessage:
encoded = []byte(d)
default:
m := metrics.Duration(MetricJSONMarshal).Start()
e, err := x.MarshalWithoutEscapeHTML(data)
m.Stop()
if err != nil {
s.InternalServerErrorWithError(w, req, errors.WithStack(err))
return nil
}
encoded = e
}
contentEncoding := negotiateContentEncoding(w, req, nil)
if contentEncoding == "" {
// If the client does not accept any compression we support (even no compression),
// we ignore that and just do not compress.
contentEncoding = compressionIdentity
} else if len(encoded) <= minCompressionSize {
contentEncoding = compressionIdentity
}
if contentEncoding != compressionIdentity {
m := metrics.Duration(MetricCompress).Start()
compressed, errE := compress(contentEncoding, encoded)
m.Stop()
if errE != nil {
s.InternalServerErrorWithError(w, req, errE)
return nil
}
// len(encoded) cannot be 0 because 0 <= minCompressionSize
// and contentEncoding is set to compressionIdentity then.
if float64(len(compressed))/float64(len(encoded)) >= minCompressionRatio {
// No need to send noncompressible files. We already used time to compress
// but we throw that away so that the client does not have to spend time decompressing.
// We do not try if any other acceptable compression might have
// a better ratio to not take too much time trying them. We assume
// that the client prefers generally the best compression anyway.
contentEncoding = compressionIdentity
} else {
encoded = compressed
}
}
md, errE := s.AddMetadata(w, req, metadata)
if errE != nil {
s.InternalServerErrorWithError(w, req, errE)
return nil
}