-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
client.go
180 lines (151 loc) · 4.92 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package gateway
import (
"errors"
"fmt"
"github.com/discordpkg/gateway/encoding"
"io"
"runtime"
"github.com/discordpkg/gateway/event"
"github.com/discordpkg/gateway/intent"
"github.com/discordpkg/gateway/internal/util"
)
var (
	// ErrOutOfSync is returned when an incoming payload's sequence number is neither the next
	// expected one nor an already-handled one, i.e. messages were missed.
	ErrOutOfSync = errors.New("sequence number was out of sync")

	// ErrNotConnectedYet is returned by Client.Write when the client is not in ConnectedState.
	ErrNotConnectedYet = errors.New("client is not in a connected state")
)
// NewClient creates a gateway Client configured by the given options.
//
// A bot token, a command rate limiter, an identify rate limiter and a heartbeat handler are
// required; NewClient returns an error if any of them is missing. When no connection
// properties are supplied, defaults derived from runtime.GOOS are used.
//
// Sharding: when no shard count is supplied, shard 0 of a single shard is assumed, while a
// non-zero shard id without a shard count is an error. Shard ids are zero-based, so the id
// must be strictly lower than the shard count.
//
// Unless an option already installed a state, the client starts in HelloState carrying an
// Identify payload built from the configuration above.
func NewClient(options ...Option) (*Client, error) {
	client := &Client{
		allowlist: util.Set[event.Type]{},
		logger:    &nopLogger{},
	}
	client.ctx = &StateCtx{client: client}

	for i := range options {
		if err := options[i](client); err != nil {
			return nil, err
		}
	}

	// Options may have replaced the default logger after the state context was created,
	// so propagate the final logger to it.
	client.ctx.logger = client.logger

	if client.botToken == "" {
		return nil, errors.New("missing bot token")
	}

	// rate limits
	if client.commandRateLimiter == nil {
		return nil, errors.New("missing command rate limiter - try 'gatewayutil.NewCommandRateLimiter()'")
	}
	if client.identifyRateLimiter == nil {
		return nil, errors.New("missing identify rate limiter - try 'gatewayutil.NewLocalIdentifyRateLimiter()'")
	}

	// connection properties
	if client.connectionProperties == nil {
		client.connectionProperties = &IdentifyConnectionProperties{
			OS:      runtime.GOOS,
			Browser: "github.com/discordpkg/gateway",
			Device:  "github.com/discordpkg/gateway",
		}
	}

	// heartbeat
	if client.heartbeatHandler == nil {
		return nil, errors.New("missing heartbeat handler - use WithHeartbeatHandler")
	}

	// sharding
	if client.totalNumberOfShards == 0 {
		if client.id != 0 {
			return nil, errors.New("missing shard count")
		}
		client.totalNumberOfShards = 1
	}
	// Shard ids are zero-based: with N shards the valid ids are 0..N-1, so id == N is
	// already out of range.
	if int(client.id) >= client.totalNumberOfShards {
		return nil, errors.New("shard id must be lower than shard count")
	}

	if client.ctx.state == nil {
		client.ctx.SetState(&HelloState{
			ctx: client.ctx,
			Identity: &Identify{
				BotToken:       client.botToken,
				Properties:     &client.connectionProperties,
				Compress:       false,
				LargeThreshold: 0,
				Shard:          [2]int{int(client.id), client.totalNumberOfShards},
				Presence:       nil,
				Intents:        client.intents,
			},
		})
	}

	return client, nil
}
// Client is the user-facing entry point for the Discord gateway: it reads and processes
// incoming payloads, writes outgoing ones, and tracks the connection state.
//
// Note: It's not suitable for internal processes/states.
type Client struct {
	// Identify configuration; NewClient requires botToken and fills in defaults for
	// totalNumberOfShards and connectionProperties when they are unset.
	botToken             string
	id                   ShardID
	totalNumberOfShards  int
	connectionProperties any
	intents              intent.Type

	// NOTE(review): allowlist presumably holds the event types the client forwards to
	// eventHandler — confirm against the options that populate them.
	allowlist    util.Set[event.Type]
	eventHandler Handler

	// Required collaborators; NewClient rejects a client missing any of these.
	commandRateLimiter  RateLimiter
	identifyRateLimiter RateLimiter
	heartbeatHandler    HeartbeatHandler

	// ctx holds the current state machine state; logger defaults to a no-op logger.
	ctx    *StateCtx
	logger Logger
}
// String returns a multi-line, human-readable summary of the client: its shard position,
// its intents and its event allowlist.
func (c *Client) String() string {
	// The "events" line previously repeated c.intents; it now reports the allowlist.
	return fmt.Sprintf(
		"shard %d out of %d shards\nintents: %v\nevents: %v\n",
		c.id, c.totalNumberOfShards, c.intents, c.allowlist,
	)
}
// ResumeURL reports which URL to dial for the next websocket connection when this shard can
// resume its session. It returns an empty string when the client is not in
// ResumableClosedState; in that case, fetch the connection URL from the "Get Gateway Bot"
// endpoint instead.
//
// The client is assumed to have been correctly closed before calling this.
func (c *Client) ResumeURL() string {
	switch c.ctx.state.(type) {
	case *ResumableClosedState:
		return c.ctx.ResumeGatewayURL
	default:
		return ""
	}
}
// Close shuts the client down by delegating to its state context, passing closeWriter along
// (presumably the destination for any close message — confirm in StateCtx.Close). Call it
// after ProcessNext reports an error.
func (c *Client) Close(closeWriter io.Writer) error {
	stateCtx := c.ctx
	return stateCtx.Close(closeWriter)
}
// read drains reader completely and decodes the bytes into a Payload. It returns the
// decoded payload together with the number of raw bytes consumed.
func (c *Client) read(reader io.Reader) (*Payload, int, error) {
	raw, readErr := io.ReadAll(reader)
	if readErr != nil {
		return nil, 0, fmt.Errorf("failed to read data. %w", readErr)
	}

	var packet Payload
	if decodeErr := encoding.Unmarshal(raw, &packet); decodeErr != nil {
		return nil, 0, fmt.Errorf("failed to unmarshal packet. %w", decodeErr)
	}

	return &packet, len(raw), nil
}
// process hands payload to the current state when it fits the client's view of the sequence
// stream.
//
//   - Payloads without a sequence number (Seq == 0, such as heartbeat ACKs) are always
//     processed, as is any payload while the client has not recorded a sequence number yet.
//   - A payload whose sequence number directly follows the recorded one advances the counter
//     and is processed.
//   - A payload at or below the recorded sequence number was already handled and is skipped.
//   - Anything else means messages were missed: the client moves to ClosedState and
//     ErrOutOfSync is returned.
func (c *Client) process(payload *Payload, pipe io.Writer) error {
	switch {
	case payload.Seq == 0, c.ctx.sequenceNumber.Load() == 0:
		// Sequence-less messages must not be checked against the stored counter. Otherwise,
		// once a dispatch has advanced it, they would be mistaken for duplicates and dropped.
		return c.ctx.Process(payload, pipe)
	case c.ctx.sequenceNumber.CompareAndSwap(payload.Seq-1, payload.Seq):
		return c.ctx.Process(payload, pipe)
	case c.ctx.sequenceNumber.Load() >= payload.Seq:
		// already handled
		return nil
	}

	c.ctx.SetState(&ClosedState{})
	return ErrOutOfSync
}
// ProcessNext reads the next Discord message from reader, lets the current state handle it
// (writing any response to writer) and returns the decoded payload. If the message cannot
// be read or decoded, the client moves to ClosedState and no payload is returned. On any
// error you are expected to call Client.Close to notify Discord about the issues
// accumulated in the Client.
func (c *Client) ProcessNext(reader io.Reader, writer io.Writer) (*Payload, error) {
	payload, _, readErr := c.read(reader)
	if readErr != nil {
		c.ctx.SetState(&ClosedState{})
		return nil, readErr
	}

	c.logger.Debug("processing payload: %s", payload)
	processErr := c.process(payload, writer)
	return payload, processErr
}
// Write sends payload as an evt event through pipe via the state context. It is only
// permitted while the client is in ConnectedState; otherwise ErrNotConnectedYet is returned
// and nothing is written.
func (c *Client) Write(pipe io.Writer, evt event.Type, payload encoding.RawMessage) error {
	if _, connected := c.ctx.state.(*ConnectedState); connected {
		return c.ctx.Write(pipe, evt, payload)
	}
	return ErrNotConnectedYet
}