From 3c0f25d54c220ef592d51e409dc8afdbeb072a9b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 14 Nov 2024 14:25:15 +0200 Subject: [PATCH] all: add some safety for client being nil --- appstate.go | 6 ++++++ client.go | 23 ++++++++++++++++++++--- errors.go | 1 + mediaconn.go | 3 +++ msgsecret.go | 6 ++++++ newsletter.go | 3 +++ pair-code.go | 3 +++ privacysettings.go | 2 +- qrchan.go | 4 +++- request.go | 3 +++ send.go | 6 +++++- sendfb.go | 4 ++++ user.go | 6 ++++++ 13 files changed, 64 insertions(+), 6 deletions(-) diff --git a/appstate.go b/appstate.go index 9f4561e6..3479ec67 100644 --- a/appstate.go +++ b/appstate.go @@ -24,6 +24,9 @@ import ( // FetchAppState fetches updates to the given type of app state. If fullSync is true, the current // cached state will be removed and all app state patches will be re-fetched from the server. func (cli *Client) FetchAppState(name appstate.WAPatchName, fullSync, onlyIfNotSynced bool) error { + if cli == nil { + return ErrClientIsNil + } cli.appStateSyncLock.Lock() defer cli.appStateSyncLock.Unlock() if fullSync { @@ -347,6 +350,9 @@ func (cli *Client) requestAppStateKeys(ctx context.Context, rawKeyIDs [][]byte) // // cli.SendAppState(appstate.BuildMute(targetJID, true, 24 * time.Hour)) func (cli *Client) SendAppState(patch appstate.PatchInfo) error { + if cli == nil { + return ErrClientIsNil + } version, hash, err := cli.Store.AppState.GetAppStateVersion(string(patch.Type)) if err != nil { return err diff --git a/client.go b/client.go index 6ba9b11d..1646a073 100644 --- a/client.go +++ b/client.go @@ -371,6 +371,9 @@ func (cli *Client) closeSocketWaitChan() { } func (cli *Client) getOwnID() types.JID { + if cli == nil { + return types.EmptyJID + } id := cli.Store.ID if id == nil { return types.EmptyJID @@ -379,6 +382,9 @@ func (cli *Client) getOwnID() types.JID { } func (cli *Client) WaitForConnection(timeout time.Duration) bool { + if cli == nil { + return false + } timeoutChan := time.After(timeout) cli.socketLock.RLock() for cli.socket == nil || !cli.socket.IsConnected() || !cli.IsLoggedIn() { @@ -398,6 +404,9 @@ func (cli *Client) WaitForConnection(timeout time.Duration) bool { // Connect connects the client to the WhatsApp web websocket. After connection, it will either // authenticate if there's data in the device store, or emit a QREvent to set up a new link. func (cli *Client) Connect() error { + if cli == nil { + return ErrClientIsNil + } cli.socketLock.Lock() defer cli.socketLock.Unlock() if cli.socket != nil { @@ -444,7 +453,7 @@ func (cli *Client) Connect() error { // IsLoggedIn returns true after the client is successfully connected and authenticated on WhatsApp. func (cli *Client) IsLoggedIn() bool { - return cli.isLoggedIn.Load() + return cli != nil && cli.isLoggedIn.Load() } func (cli *Client) onDisconnect(ns *socket.NoiseSocket, remote bool) { @@ -508,6 +517,9 @@ func (cli *Client) autoReconnect() { // IsConnected checks if the client is connected to the WhatsApp web websocket. // Note that this doesn't check if the client is authenticated. See the IsLoggedIn field for that. func (cli *Client) IsConnected() bool { + if cli == nil { + return false + } cli.socketLock.RLock() connected := cli.socket != nil && cli.socket.IsConnected() cli.socketLock.RUnlock() @@ -519,7 +531,7 @@ func (cli *Client) IsConnected() bool { // This will not emit any events, the Disconnected event is only used when the // connection is closed by the server or a network error. func (cli *Client) Disconnect() { - if cli.socket == nil { + if cli == nil || cli.socket == nil { return } cli.socketLock.Lock() @@ -544,7 +556,9 @@ func (cli *Client) unlockedDisconnect() { // Note that this will not emit any events. The LoggedOut event is only used for external logouts // (triggered by the user from the main device or by WhatsApp servers). func (cli *Client) Logout() error { - if cli.MessengerConfig != nil { + if cli == nil { + return ErrClientIsNil + } else if cli.MessengerConfig != nil { return errors.New("can't logout with Messenger credentials") } ownID := cli.getOwnID() @@ -728,6 +742,9 @@ func (cli *Client) handlerQueueLoop(ctx context.Context) { } func (cli *Client) sendNodeAndGetData(node waBinary.Node) ([]byte, error) { + if cli == nil { + return nil, ErrClientIsNil + } cli.socketLock.RLock() sock := cli.socket cli.socketLock.RUnlock() diff --git a/errors.go b/errors.go index 101436b7..00d63904 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,7 @@ import ( // Miscellaneous errors var ( + ErrClientIsNil = errors.New("client is nil") ErrNoSession = errors.New("can't encrypt message for device: no signal session established") ErrIQTimedOut = errors.New("info query timed out") ErrNotConnected = errors.New("websocket not connected") diff --git a/mediaconn.go b/mediaconn.go index 2e833037..4576b22a 100644 --- a/mediaconn.go +++ b/mediaconn.go @@ -41,6 +41,9 @@ func (mc *MediaConn) Expiry() time.Time { } func (cli *Client) refreshMediaConn(force bool) (*MediaConn, error) { + if cli == nil { + return nil, ErrClientIsNil + } cli.mediaConnLock.Lock() defer cli.mediaConnLock.Unlock() if cli.mediaConnCache == nil || force || time.Now().After(cli.mediaConnCache.Expiry()) { diff --git a/msgsecret.go b/msgsecret.go index 31822f7b..20ddcf2f 100644 --- a/msgsecret.go +++ b/msgsecret.go @@ -83,6 +83,9 @@ type messageEncryptedSecret interface { } func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, encrypted messageEncryptedSecret, origMsgKey *waCommon.MessageKey) ([]byte, error) { + if cli == nil { + return nil, ErrClientIsNil + } pollSender, err := getOrigSenderFromKey(msg, origMsgKey) if err != nil { return nil, err @@ -102,6 +105,9 @@ func (cli *Client) decryptMsgSecret(msg *events.Message, useCase MsgSecretType, } func (cli *Client) encryptMsgSecret(chat, origSender types.JID, origMsgID types.MessageID, useCase MsgSecretType, plaintext []byte) (ciphertext, iv []byte, err error) { + if cli == nil { + return nil, nil, ErrClientIsNil + } ownID := cli.getOwnID() if ownID.IsEmpty() { return nil, nil, ErrNotLoggedIn diff --git a/newsletter.go b/newsletter.go index 00ba7b1f..8f23a917 100644 --- a/newsletter.go +++ b/newsletter.go @@ -40,6 +40,9 @@ func (cli *Client) NewsletterSubscribeLiveUpdates(ctx context.Context, jid types // // This is not the same as marking the channel as read on your other devices, use the usual MarkRead function for that. func (cli *Client) NewsletterMarkViewed(jid types.JID, serverIDs []types.MessageServerID) error { + if cli == nil { + return ErrClientIsNil + } items := make([]waBinary.Node, len(serverIDs)) for i, id := range serverIDs { items[i] = waBinary.Node{ diff --git a/pair-code.go b/pair-code.go index d1a8497b..a1663b18 100644 --- a/pair-code.go +++ b/pair-code.go @@ -87,6 +87,9 @@ func generateCompanionEphemeralKey() (ephemeralKeyPair *keys.KeyPair, ephemeralK // // See https://faq.whatsapp.com/1324084875126592 for more info func (cli *Client) PairPhone(phone string, showPushNotification bool, clientType PairClientType, clientDisplayName string) (string, error) { + if cli == nil { + return "", ErrClientIsNil + } ephemeralKeyPair, ephemeralKey, encodedLinkingCode := generateCompanionEphemeralKey() phone = notNumbers.ReplaceAllString(phone, "") if len(phone) <= 6 { diff --git a/privacysettings.go b/privacysettings.go index 6c2425a8..2b0f46fc 100644 --- a/privacysettings.go +++ b/privacysettings.go @@ -42,7 +42,7 @@ func (cli *Client) TryFetchPrivacySettings(ignoreCache bool) (*types.PrivacySett // GetPrivacySettings will get the user's privacy settings. If an error occurs while fetching them, the error will be // logged, but the method will just return an empty struct. func (cli *Client) GetPrivacySettings() (settings types.PrivacySettings) { - if cli.MessengerConfig != nil { + if cli == nil || cli.MessengerConfig != nil { return } settingsPtr, err := cli.TryFetchPrivacySettings(false) diff --git a/qrchan.go b/qrchan.go index 5401ac0e..ba29c683 100644 --- a/qrchan.go +++ b/qrchan.go @@ -159,7 +159,9 @@ func (qrc *qrChannel) handleEvent(rawEvt interface{}) { // The last value to be emitted will be a special event like "success", "timeout" or another error code // depending on the result of the pairing. The channel will be closed immediately after one of those. func (cli *Client) GetQRChannel(ctx context.Context) (<-chan QRChannelItem, error) { - if cli.IsConnected() { + if cli == nil { + return nil, ErrClientIsNil + } else if cli.IsConnected() { return nil, ErrQRAlreadyConnected } else if cli.Store.ID != nil { return nil, ErrQRStoreContainsID diff --git a/request.go b/request.go index 2500f95b..e928ec6d 100644 --- a/request.go +++ b/request.go @@ -106,6 +106,9 @@ type infoQuery struct { } func (cli *Client) sendIQAsyncAndGetData(query *infoQuery) (<-chan *waBinary.Node, []byte, error) { + if cli == nil { + return nil, nil, ErrClientIsNil + } if len(query.ID) == 0 { query.ID = cli.generateRequestID() } diff --git a/send.go b/send.go index ae5b0ba0..9ed2ac43 100644 --- a/send.go +++ b/send.go @@ -40,7 +40,7 @@ import ( // msgID := cli.GenerateMessageID() // cli.SendMessage(context.Background(), targetJID, &waProto.Message{...}, whatsmeow.SendRequestExtra{ID: msgID}) func (cli *Client) GenerateMessageID() types.MessageID { - if cli.MessengerConfig != nil { + if cli != nil && cli.MessengerConfig != nil { return types.MessageID(strconv.FormatInt(GenerateFacebookMessageID(), 10)) } data := make([]byte, 8, 8+20+16) @@ -167,6 +167,10 @@ type SendRequestExtra struct { // field in incoming message events to figure out what it contains is also a good way to learn how to // send the same kind of message. func (cli *Client) SendMessage(ctx context.Context, to types.JID, message *waE2E.Message, extra ...SendRequestExtra) (resp SendResponse, err error) { + if cli == nil { + err = ErrClientIsNil + return + } var req SendRequestExtra if len(extra) > 1 { err = errors.New("only one extra parameter may be provided to SendMessage") diff --git a/sendfb.go b/sendfb.go index 660d2669..5c3e52e7 100644 --- a/sendfb.go +++ b/sendfb.go @@ -47,6 +47,10 @@ func (cli *Client) SendFBMessage( metadata *waMsgApplication.MessageApplication_Metadata, extra ...SendRequestExtra, ) (resp SendResponse, err error) { + if cli == nil { + err = ErrClientIsNil + return + } var req SendRequestExtra if len(extra) > 1 { err = errors.New("only one extra parameter may be provided to SendMessage") diff --git a/user.go b/user.go index 58fb6328..66c5b14a 100644 --- a/user.go +++ b/user.go @@ -728,6 +728,9 @@ type UsyncQueryExtras struct { } func (cli *Client) usync(ctx context.Context, jids []types.JID, mode, context string, query []waBinary.Node, extra ...UsyncQueryExtras) (*waBinary.Node, error) { + if cli == nil { + return nil, ErrClientIsNil + } var extras UsyncQueryExtras if len(extra) > 1 { return nil, errors.New("only one extra parameter may be provided to usync()") @@ -844,6 +847,9 @@ func (cli *Client) UpdateBlocklist(jid types.JID, action events.BlocklistChangeA }, }}, }) + if err != nil { + return nil, err + } list, ok := resp.GetOptionalChildByTag("list") if !ok { return nil, &ElementMissingError{Tag: "list", In: "response to blocklist update"}