-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.go
342 lines (318 loc) · 9.33 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
package goRPC
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"goRPC/Maincodec/codec"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
)
type Call struct {
Seq uint64
ServiceMethod string // 函数标准"<service>.<method>"
Args interface{} // 参数
Reply interface{} // 返回值
Error error //报错
Done chan *Call // 呼叫完成时选通
}
// 支持异步调用,当调用结束的时候用call.done通知调用方
func (call *Call) done() {
call.Done <- call
}
type Client struct {
cc codec.Codec //编码器
opt *Option //编码器选择
sending sync.Mutex //互斥锁,和服务端类似,为了保证请求的有序发送
header codec.Header //消息头
mu sync.Mutex
seq uint64 //用于给发送的请求编号,每个请求拥有唯一编号
pending map[uint64]*Call //存储未处理完的请求,key是seq
//closing 和 shutdown 任意一个值置为 true,则表示 Client 处于不可用的状态
closing bool // 用户主动关闭的
shutdown bool // shutdown 置为 true 一般是有错误发生
}
// io.Closer需要实现
var _ io.Closer = (*Client)(nil)
// 报错信息
var ErrShutdown = errors.New("connection is shut down")
// 关闭连接
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
//关闭链接实例
return client.cc.Close()
}
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
//将参数 call 添加到 client.pending 中,并更新 client.seq
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
//根据 seq,从 client.pending 中移除对应的 call,并返回
call := client.pending[seq]
delete(client.pending, seq)
return call
}
// 服务端或客户端发生错误时调用,将 shutdown 设置为 true,且将错误信息通知所有 pending 状态的 call。
func (client *Client) terminateCalls(err error) {
client.sending.Lock()
defer client.sending.Unlock()
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
if err = client.cc.ReadHeader(&h); err != nil {
break
}
call := client.removeCall(h.Seq)
switch {
//call 不存在,可能是请求没有发送完整,或者因为其他原因被取消
case call == nil:
err = client.cc.ReadBody(nil)
//call 存在,但服务端处理出错,即 h.Error 不为空
case h.Error != "":
call.Error = fmt.Errorf(h.Error)
err = client.cc.ReadBody(nil)
call.done()
//call 存在,服务端处理正常,那么需要从 body 中读取 Reply 的值。
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body" + err.Error())
}
call.done()
}
}
client.terminateCalls(err)
}
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
//先完成协议交换
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return nil, err
}
// send options with server
if err := json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn), opt), nil
}
func newClientCodec(cc codec.Codec, opt *Option) *Client {
client := &Client{
seq: 1, // seq starts with 1, 0 means invalid call
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
//接收
go client.receive()
return client
}
// 接收参数
func parseOptions(opts ...*Option) (*Option, error) {
// if opts is nil or pass nil as parameter
if len(opts) == 0 || opts[0] == nil {
return DefaultOption, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = DefaultOption.MagicNumber
if opt.CodecType == "" {
opt.CodecType = DefaultOption.CodecType
}
return opt, nil
}
// Dial 函数,便于用户传入服务端地址,创建 Client 实例。
func Dial(network, address string, opts ...*Option) (*Client, error) {
return DialTimeout(NewClient, network, address, opts...)
}
/*
//初始版本
// Dial 函数,便于用户传入服务端地址,创建 Client 实例。
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
*/
func (client *Client) send(call *Call) {
//确保客户端将发送一个完整的请求
client.sending.Lock()
defer client.sending.Unlock()
// 注册请求创建其序列号
seq, err := client.registerCall(call)
if err != nil {
call.Error = err
call.done()
return
}
//创建请求头
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
// 编码
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
// call为nil表示write部分失效
// 客户端收到响应并进行处理
if call != nil {
call.Error = err
call.done()
}
}
}
// Go 和 Call 是客户端暴露给用户的两个 RPC 服务调用接口,Go 是一个异步接口,返回 call 实例。
// go异步调用函数
// 返回Call结构体
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}
// RPC调用函数名
// 返回错误值
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
}
type clientResult struct {
client *Client
err error
}
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)
// 将 net.Dial 替换为 net.DialTimeout,如果连接创建超时,将返回错误。
// 2)使用子协程执行 NewClient,执行完成后则通过信道 ch 发送结果,如果 time.After() 信道先接收到消息,则说明 NewClient 执行超时,返回错误。
func DialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if err != nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
select {
case <-time.After(opt.ConnectTimeout): //先于 called 接收到消息,说明处理已经超时
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))
// Require successful HTTP response
// before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
//print(resp.Status)
if err == nil && resp.Status == connected {
return NewClient(conn, opt)
}
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
return nil, err
}
// DialHTTP连接到指定网络地址的HTTP RPC服务器
func DialHTTP(network, address string, opts ...*Option) (*Client, error) {
return DialTimeout(NewHTTPClient, network, address, opts...)
}
// XDial calls different functions to connect to a RPC server
// according the first parameter rpcAddr.
// rpcAddr is a general format (protocol@addr) to represent a rpc server
// eg, [email protected]:7001, [email protected]:9999, unix@/tmp/rpc.sock
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
parts := strings.Split(rpcAddr, "@")
if len(parts) != 2 {
return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
}
protocol, addr := parts[0], parts[1]
switch protocol {
case "http":
return DialHTTP("tcp", addr, opts...)
default:
// tcp, unix or other transport protocol
return Dial(protocol, addr, opts...)
}
}