-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.go
263 lines (237 loc) · 7.42 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
package minirpc
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"minirpc/codec"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"
)
// MagicNumber is the value a client must send in Option.MagicNumber;
// ServeConn drops any connection whose option carries a different value.
const MagicNumber = 0x3bef5c
// String constants used by the HTTP transport.
const (
// connected is the status text ServeHTTP writes after "HTTP/1.0 " once the
// connection has been hijacked.
connected = "200 Connected to Mini RPC"
// defaultRPCPath is the URL path on which HandleHTTP registers the server.
defaultRPCPath = "/_minirpc_"
// defaultDebugPath is not referenced in this part of the file; presumably it
// is the path of a debug handler registered elsewhere — TODO confirm.
defaultDebugPath = "/debug/minirpc"
)
// Option is the connection preamble. ServeConn decodes it as JSON from the
// start of each connection before any request is read.
type Option struct {
MagicNumber int // MagicNumber marks this as a minirpc request
CodecType codec.Type // client may choose a different Codec to encode the body
ConnectTimeout time.Duration // 0 means no limit
// HandleTimeout bounds how long the server waits for a method call to
// return before answering with a timeout error; 0 means no limit.
HandleTimeout time.Duration
}
// DefaultOption is a ready-made Option carrying the expected MagicNumber,
// codec.GobType as the codec and a 10-second ConnectTimeout. HandleTimeout is
// left at its zero value.
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 10,
}
// Server represents an RPC Server.
type Server struct {
// serviceMap holds the registered services keyed by service name; Register
// stores *service values in it and findService loads them.
serviceMap sync.Map
}
// NewServer returns a new Server with no services registered.
func NewServer() *Server {
	return new(Server)
}
// DefaultServer is the package-level Server used by the package-level Accept,
// Register and HandleHTTP functions.
var DefaultServer = NewServer()
// ServeConn runs the server on a single connection.
//
// It decodes the JSON Option preamble, checks the magic number, looks up the
// codec constructor for the requested codec type and then hands the
// connection to serveCodec. ServeConn blocks until the client hangs up; the
// connection is closed before it returns.
//
// NOTE(review): json.Decoder buffers its input and may read past the Option
// value. Bytes it has already buffered are not seen by the codec created
// below — confirm the client side accounts for this.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	defer func() { _ = conn.Close() }()

	opt := new(Option)
	err := json.NewDecoder(conn).Decode(opt)
	if err != nil {
		log.Println("rpc server: options error: ", err)
		return
	}

	if got := opt.MagicNumber; got != MagicNumber {
		log.Printf("rpc server: invalid magic number %x", got)
		return
	}

	newCodec := codec.NewCodecFuncMap[opt.CodecType]
	if newCodec == nil {
		log.Printf("rpc server: invalid codec type %s", opt.CodecType)
		return
	}

	server.serveCodec(newCodec(conn), opt)
}
// invalidRequest is the placeholder body sent in a response whose header
// carries an error.
var invalidRequest = struct{}{}
// serveCodec is the request loop for one connection. It reads requests one
// at a time and handles each in its own goroutine, so several requests may be
// in flight at once. When a request cannot be read at all the loop stops,
// waits for the in-flight handlers and closes the codec.
func (server *Server) serveCodec(cc codec.Codec, opt *Option) {
	var (
		sending sync.Mutex     // responses must be written one at a time, never interleaved
		pending sync.WaitGroup // counts handlers that have not finished yet
	)

	for {
		// Read the next request.
		req, err := server.readRequest(cc)
		if err == nil {
			// Handle the request concurrently.
			pending.Add(1)
			go server.handleRequest(cc, req, &sending, &pending, opt.HandleTimeout)
			continue
		}
		if req == nil {
			// Not even a header could be read: it is not possible to
			// recover, so stop serving this connection.
			break
		}
		// The header was read but the request is unusable. Report the error;
		// the sending lock keeps this write from interleaving with responses
		// written by concurrent handlers, which the client could not parse.
		req.h.Error = err.Error()
		server.sendResponse(cc, req.h, invalidRequest, &sending)
	}

	pending.Wait()
	_ = cc.Close()
}
// request stores all information of a call.
type request struct {
// h is the request header; the server also passes it to sendResponse as the
// response header, setting h.Error on failure.
h *codec.Header
argv, replyv reflect.Value // argv and replyv of request
// mtype is the method resolved by findService.
mtype *methodType
// svc is the service resolved by findService.
svc *service
}
// readRequestHeader reads the next request header from cc. On failure it
// returns a nil header together with the error; the error is logged unless it
// is io.EOF or io.ErrUnexpectedEOF.
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
	header := new(codec.Header)
	err := cc.ReadHeader(header)
	switch err {
	case nil:
		return header, nil
	case io.EOF, io.ErrUnexpectedEOF:
		// End of stream: returned to the caller without logging.
	default:
		log.Println("rpc server: read header error: ", err)
	}
	return nil, err
}
// readRequest reads one complete request (header and body) from cc.
//
// It returns a nil request when the header itself cannot be read. When the
// header was read but the service lookup or the body decoding fails, it
// returns the partially filled request together with the error, so that the
// caller can still answer using the header.
//
// NOTE(review): when findService fails, the request body is not read from
// cc — confirm the codec stream stays aligned for the next header.
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
	header, err := server.readRequestHeader(cc)
	if err != nil {
		return nil, err
	}

	req := &request{h: header}
	if req.svc, req.mtype, err = server.findService(header.ServiceMethod); err != nil {
		return req, err
	}

	// Create fresh instances for the two method parameters.
	req.argv, req.replyv = req.mtype.newArgv(), req.mtype.newReplyv()

	// ReadBody needs a pointer to decode into: use argv itself when it is
	// already a pointer, otherwise take its address.
	var target interface{}
	if req.argv.Type().Kind() == reflect.Ptr {
		target = req.argv.Interface()
	} else {
		target = req.argv.Addr().Interface()
	}

	// Deserialize the request body into the first parameter, argv.
	err = cc.ReadBody(target)
	if err != nil {
		log.Println("rpc server: read body err: ", err)
		return req, err
	}
	return req, nil
}
// sendResponse writes one response (header h and body) to cc while holding
// the sending mutex, so that concurrent responses are not interleaved. A
// write error is logged and not returned.
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
	sending.Lock()
	defer sending.Unlock()

	err := cc.Write(h, body)
	if err == nil {
		return
	}
	log.Println("rpc server: write response err: ", err)
}
// handleRequest invokes the method named by req and sends exactly one
// response on cc, then marks the request as done on wg.
//
// A timeout of 0 means no limit: the method is called inline and its result
// (or error) is sent back. Otherwise the method runs in a worker goroutine
// and the handler waits for whichever comes first:
//   - the call returns in time: the reply, or the call's error, is sent;
//   - the timer fires first: a timeout error is sent instead.
//
// The timeout covers only the method call, not the writing of the response.
// All response writing happens in this goroutine, which is what guarantees a
// single response per request.
//
// The result channel is buffered so that a worker which finishes after the
// timeout can still deliver its result and exit; with an unbuffered channel
// it would block forever on the send and leak. The method itself cannot be
// cancelled: after a timeout it keeps running until it returns, and its
// result is then discarded.
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
	defer wg.Done()

	// reply sends the outcome of a finished call.
	reply := func(err error) {
		if err != nil {
			req.h.Error = err.Error()
			server.sendResponse(cc, req.h, invalidRequest, sending)
			return
		}
		server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
	}

	if timeout == 0 {
		reply(req.svc.call(req.mtype, req.argv, req.replyv))
		return
	}

	done := make(chan error, 1) // capacity 1: the worker never blocks on send
	go func() {
		done <- req.svc.call(req.mtype, req.argv, req.replyv)
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop() // release the timer promptly when the call wins

	select {
	case <-timer.C:
		req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
		server.sendResponse(cc, req.h, invalidRequest, sending)
	case err := <-done:
		reply(err)
	}
}
// Accept accepts connections on the listener and serves requests for each
// incoming connection. Every accepted connection is handed to ServeConn in
// its own goroutine. Accept blocks until the listener returns an error, which
// is logged before Accept returns.
func (server *Server) Accept(lis net.Listener) {
	for {
		conn, err := lis.Accept()
		if err == nil {
			go server.ServeConn(conn)
			continue
		}
		log.Println("rpc server: accept error: ", err)
		return
	}
}
// Accept accepts connections on the listener and serves them with the
// DefaultServer.
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
// Register publishes the receiver's methods on the server under the service
// name computed by newService. It returns an error if a service with that
// name is already registered.
func (server *Server) Register(rcvr interface{}) error {
	svc := newService(rcvr)
	_, loaded := server.serviceMap.LoadOrStore(svc.name, svc)
	if !loaded {
		return nil
	}
	return errors.New("rpc: service already defined: " + svc.name)
}
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error {
	return DefaultServer.Register(rcvr)
}
// findService resolves a "Service.Method" string to the registered service
// and its method. The string is split at the last dot. An error is returned
// when there is no dot, when the service is not registered, or when the
// service has no such method; in the last case the service is still returned
// alongside the error.
func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
	// Split "Service.Method" at the last dot.
	sep := strings.LastIndex(serviceMethod, ".")
	if sep < 0 {
		return nil, nil, errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
	}
	serviceName := serviceMethod[:sep]
	methodName := serviceMethod[sep+1:]

	// Look up the service instance in serviceMap.
	entry, found := server.serviceMap.Load(serviceName)
	if !found {
		return nil, nil, errors.New("rpc server: can't find service " + serviceName)
	}
	svc = entry.(*service)

	// Then look up the method within that service.
	if mtype = svc.method[methodName]; mtype == nil {
		return svc, nil, errors.New("rpc server: can't find method " + methodName)
	}
	return svc, mtype, nil
}
// ServeHTTP implements an http.Handler that answers RPC requests.
//
// Only the CONNECT method is accepted; any other method gets a 405 reply.
// For CONNECT the underlying connection is hijacked, the "connected" status
// line is written to it, and the connection is then served by ServeConn,
// which blocks until the client hangs up.
//
// If the ResponseWriter does not implement http.Hijacker, the request is
// answered with a 500 instead of panicking on the type assertion.
func (server *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodConnect {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must CONNECT\n")
		return
	}

	hijacker, ok := w.(http.Hijacker)
	if !ok {
		log.Print("rpc hijacking ", r.RemoteAddr, ": response writer does not support hijacking")
		http.Error(w, "500 hijacking not supported", http.StatusInternalServerError)
		return
	}

	conn, _, err := hijacker.Hijack()
	if err != nil {
		log.Print("rpc hijacking ", r.RemoteAddr, ": ", err.Error())
		return
	}

	// Best effort: if this write fails, ServeConn fails on its first read
	// and closes the connection.
	_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
	server.ServeConn(conn)
}
// HandleHTTP registers the server as the handler for defaultRPCPath on
// http.DefaultServeMux. It is still necessary to invoke http.Serve(),
// typically in a go statement.
func (server *Server) HandleHTTP() {
	http.DefaultServeMux.Handle(defaultRPCPath, server)
}
// HandleHTTP registers the HTTP handler of the DefaultServer; it is a
// convenience wrapper around DefaultServer.HandleHTTP.
func HandleHTTP() { DefaultServer.HandleHTTP() }