diff --git a/README.md b/README.md index b6df25c6..a2ce69fe 100644 --- a/README.md +++ b/README.md @@ -248,3 +248,38 @@ doxygen docs/doxygen/config.doxyfile See [CONTRIBUTING.md](./.github/CONTRIBUTING.md) for information on contributing. + +## Using Setter and Getter Functions for Persistent Storage + +The coreMQTT library provides setter and getter functions to allow the application to store and restore MQTT state in persistent memory. This is useful for handling QoS2 messages after a device reboot. + +### Setter Function + +The `MQTT_SetOutgoingPublishRecord` function allows the application to set an outgoing publish record in the MQTT context. It can be used to restore the state of the MQTT context after a device reboot. + +```c +MQTTStatus_t MQTT_SetOutgoingPublishRecord( MQTTContext_t * pMqttContext, + uint16_t packetId, + MQTTQoS_t qos, + MQTTPublishState_t publishState ); +``` + +### Getter Function + +The `MQTT_GetOutgoingPublishRecord` function allows the application to get an outgoing publish record from the MQTT context. It can be used to store the state of the MQTT context in persistent memory before a device reboot. + +```c +MQTTStatus_t MQTT_GetOutgoingPublishRecord( const MQTTContext_t * pMqttContext, + uint16_t packetId, + MQTTQoS_t * pQos, + MQTTPublishState_t * pPublishState ); +``` + +### Getting the Failed Packet ID + +The `MQTT_GetFailedPacketId` function allows the application to get the packet ID of the failed packet from the MQTT context. It can be used to handle the situation when the library loses state after a device reboot. + +```c +MQTTStatus_t MQTT_GetFailedPacketId( const MQTTContext_t * pMqttContext, + uint16_t * pPacketId ); +``` diff --git a/source/core_mqtt_state.c b/source/core_mqtt_state.c index 151c24b8..25b71b2e 100644 --- a/source/core_mqtt_state.c +++ b/source/core_mqtt_state.c @@ -1204,3 +1204,93 @@ const char * MQTT_State_strerror( MQTTPublishState_t state ) } /*-----------------------------------------------------------*/ + +MQTTStatus_t MQTT_SetOutgoingPublishRecord( MQTTContext_t * pMqttContext, + uint16_t packetId, + MQTTQoS_t qos, + MQTTPublishState_t publishState ) +{ + MQTTStatus_t status = MQTTSuccess; + + if( ( pMqttContext == NULL ) || ( packetId == MQTT_PACKET_ID_INVALID ) || ( qos == MQTTQoS0 ) ) + { + status = MQTTBadParameter; + } + else + { + status = addRecord( pMqttContext->outgoingPublishRecords, + pMqttContext->outgoingPublishRecordMaxCount, + packetId, + qos, + publishState ); + } + + return status; +} + +MQTTStatus_t MQTT_GetOutgoingPublishRecord( const MQTTContext_t * pMqttContext, + uint16_t packetId, + MQTTQoS_t * pQos, + MQTTPublishState_t * pPublishState ) +{ + MQTTStatus_t status = MQTTSuccess; + size_t recordIndex; + + if( ( pMqttContext == NULL ) || ( packetId == MQTT_PACKET_ID_INVALID ) || ( pQos == NULL ) || ( pPublishState == NULL ) ) + { + status = MQTTBadParameter; + } + else + { + recordIndex = findInRecord( pMqttContext->outgoingPublishRecords, + pMqttContext->outgoingPublishRecordMaxCount, + packetId, + pQos, + pPublishState ); + + if( recordIndex == MQTT_INVALID_STATE_COUNT ) + { + status = MQTTBadParameter; + } + } + + return status; +} + +MQTTStatus_t MQTT_GetFailedPacketId( const MQTTContext_t * pMqttContext, + uint16_t * pPacketId ) +{ + MQTTStatus_t status = MQTTSuccess; + size_t recordIndex; + MQTTQoS_t qos; + MQTTPublishState_t publishState; + + if( ( pMqttContext == NULL ) || ( pPacketId == NULL ) ) + { + status = MQTTBadParameter; + } + else + { + for( recordIndex = 0; recordIndex < pMqttContext->outgoingPublishRecordMaxCount; recordIndex++ ) + { + if( pMqttContext->outgoingPublishRecords[ recordIndex ].packetId != MQTT_PACKET_ID_INVALID ) + { + qos = pMqttContext->outgoingPublishRecords[ recordIndex ].qos; + publishState = pMqttContext->outgoingPublishRecords[ recordIndex ].publishState; + + if( ( qos == MQTTQoS2 ) && ( publishState == MQTTPubRelSend ) ) + { + *pPacketId = pMqttContext->outgoingPublishRecords[ recordIndex ].packetId; + break; + } + } + } + + if( recordIndex == pMqttContext->outgoingPublishRecordMaxCount ) + { + status = MQTTBadParameter; + } + } + + return status; +} diff --git a/source/include/core_mqtt_state.h b/source/include/core_mqtt_state.h index bc229f66..2b9857f6 100644 --- a/source/include/core_mqtt_state.h +++ b/source/include/core_mqtt_state.h @@ -15,10 +15,10 @@ * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ @@ -301,6 +301,59 @@ uint16_t MQTT_PublishToResend( const MQTTContext_t * pMqttContext, const char * MQTT_State_strerror( MQTTPublishState_t state ); /** @endcond */ +/** + * @brief Set an outgoing publish record in the MQTT context. + * + * This function allows the application to set an outgoing publish record in the + * MQTT context. It can be used to restore the state of the MQTT context after a + * device reboot. + * + * @param[in] pMqttContext Initialized MQTT context. + * @param[in] packetId ID of the PUBLISH packet. + * @param[in] qos QoS of the PUBLISH packet. + * @param[in] publishState State of the PUBLISH packet. + * + * @return #MQTTBadParameter, #MQTTNoMemory, or #MQTTSuccess. + */ +MQTTStatus_t MQTT_SetOutgoingPublishRecord( MQTTContext_t * pMqttContext, + uint16_t packetId, + MQTTQoS_t qos, + MQTTPublishState_t publishState ); + +/** + * @brief Get an outgoing publish record from the MQTT context. + * + * This function allows the application to get an outgoing publish record from the + * MQTT context. It can be used to store the state of the MQTT context in persistent + * memory before a device reboot. + * + * @param[in] pMqttContext Initialized MQTT context. + * @param[in] packetId ID of the PUBLISH packet. + * @param[out] pQos QoS of the PUBLISH packet. + * @param[out] pPublishState State of the PUBLISH packet. + * + * @return #MQTTBadParameter, #MQTTSuccess. + */ +MQTTStatus_t MQTT_GetOutgoingPublishRecord( const MQTTContext_t * pMqttContext, + uint16_t packetId, + MQTTQoS_t * pQos, + MQTTPublishState_t * pPublishState ); + +/** + * @brief Get the packet ID of the failed packet. + * + * This function allows the application to get the packet ID of the failed packet + * from the MQTT context. It can be used to handle the situation when the library + * loses state after a device reboot. + * + * @param[in] pMqttContext Initialized MQTT context. + * @param[out] pPacketId ID of the failed packet. + * + * @return #MQTTBadParameter, #MQTTSuccess. + */ +MQTTStatus_t MQTT_GetFailedPacketId( const MQTTContext_t * pMqttContext, + uint16_t * pPacketId ); + /* *INDENT-OFF* */ #ifdef __cplusplus } diff --git a/test/unit-test/core_mqtt_state_utest.c b/test/unit-test/core_mqtt_state_utest.c index fd04d152..65139380 100644 --- a/test/unit-test/core_mqtt_state_utest.c +++ b/test/unit-test/core_mqtt_state_utest.c @@ -15,10 +15,10 @@ * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ @@ -1168,3 +1168,143 @@ void test_MQTT_State_strerror( void ) } /* ========================================================================== */ + +void test_MQTT_SetOutgoingPublishRecord( void ) +{ + MQTTContext_t mqttContext = { 0 }; + MQTTStatus_t status; + const uint16_t PACKET_ID = 1; + const MQTTQoS_t qos = MQTTQoS2; + const MQTTPublishState_t publishState = MQTTPubRelSend; + TransportInterface_t transport; + MQTTFixedBuffer_t networkBuffer = { 0 }; + + transport.recv = transportRecvSuccess; + transport.send = transportSendSuccess; + + MQTTPubAckInfo_t incomingRecords[ MQTT_STATE_ARRAY_MAX_COUNT ] = { 0 }; + MQTTPubAckInfo_t outgoingRecords[ MQTT_STATE_ARRAY_MAX_COUNT ] = { 0 }; + + status = MQTT_Init( &mqttContext, &transport, + getTime, eventCallback, &networkBuffer ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + + status = MQTT_InitStatefulQoS( &mqttContext, + outgoingRecords, MQTT_STATE_ARRAY_MAX_COUNT, + incomingRecords, MQTT_STATE_ARRAY_MAX_COUNT ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + + /* Test for bad parameters */ + status = MQTT_SetOutgoingPublishRecord( NULL, PACKET_ID, qos, publishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + status = MQTT_SetOutgoingPublishRecord( &mqttContext, MQTT_PACKET_ID_INVALID, qos, publishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + status = MQTT_SetOutgoingPublishRecord( &mqttContext, PACKET_ID, MQTTQoS0, publishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + + /* Success. */ + status = MQTT_SetOutgoingPublishRecord( &mqttContext, PACKET_ID, qos, publishState ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + /* Verify the record is added correctly. */ + TEST_ASSERT_EQUAL( PACKET_ID, mqttContext.outgoingPublishRecords[ 0 ].packetId ); + TEST_ASSERT_EQUAL( qos, mqttContext.outgoingPublishRecords[ 0 ].qos ); + TEST_ASSERT_EQUAL( publishState, mqttContext.outgoingPublishRecords[ 0 ].publishState ); +} + +/* ========================================================================== */ + +void test_MQTT_GetOutgoingPublishRecord( void ) +{ + MQTTContext_t mqttContext = { 0 }; + MQTTStatus_t status; + const uint16_t PACKET_ID = 1; + const MQTTQoS_t qos = MQTTQoS2; + const MQTTPublishState_t publishState = MQTTPubRelSend; + MQTTQoS_t retrievedQos; + MQTTPublishState_t retrievedPublishState; + TransportInterface_t transport; + MQTTFixedBuffer_t networkBuffer = { 0 }; + + transport.recv = transportRecvSuccess; + transport.send = transportSendSuccess; + + MQTTPubAckInfo_t incomingRecords[ MQTT_STATE_ARRAY_MAX_COUNT ] = { 0 }; + MQTTPubAckInfo_t outgoingRecords[ MQTT_STATE_ARRAY_MAX_COUNT ] = { 0 }; + + status = MQTT_Init( &mqttContext, &transport, + getTime, eventCallback, &networkBuffer ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + + status = MQTT_InitStatefulQoS( &mqttContext, + outgoingRecords, MQTT_STATE_ARRAY_MAX_COUNT, + incomingRecords, MQTT_STATE_ARRAY_MAX_COUNT ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + + /* Test for bad parameters */ + status = MQTT_GetOutgoingPublishRecord( NULL, PACKET_ID, &retrievedQos, &retrievedPublishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + status = MQTT_GetOutgoingPublishRecord( &mqttContext, MQTT_PACKET_ID_INVALID, &retrievedQos, &retrievedPublishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + status = MQTT_GetOutgoingPublishRecord( &mqttContext, PACKET_ID, NULL, &retrievedPublishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + status = MQTT_GetOutgoingPublishRecord( &mqttContext, PACKET_ID, &retrievedQos, NULL ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + + /* No record found. */ + status = MQTT_GetOutgoingPublishRecord( &mqttContext, PACKET_ID, &retrievedQos, &retrievedPublishState ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + + /* Success. */ + addToRecord( mqttContext.outgoingPublishRecords, 0, PACKET_ID, qos, publishState ); + status = MQTT_GetOutgoingPublishRecord( &mqttContext, PACKET_ID, &retrievedQos, &retrievedPublishState ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + /* Verify the record is retrieved correctly. */ + TEST_ASSERT_EQUAL( qos, retrievedQos ); + TEST_ASSERT_EQUAL( publishState, retrievedPublishState ); +} + +/* ========================================================================== */ + +void test_MQTT_GetFailedPacketId( void ) +{ + MQTTContext_t mqttContext = { 0 }; + MQTTStatus_t status; + const uint16_t PACKET_ID = 1; + const MQTTQoS_t qos = MQTTQoS2; + const MQTTPublishState_t publishState = MQTTPubRelSend; + uint16_t retrievedPacketId; + TransportInterface_t transport; + MQTTFixedBuffer_t networkBuffer = { 0 }; + + transport.recv = transportRecvSuccess; + transport.send = transportSendSuccess; + + MQTTPubAckInfo_t incomingRecords[ MQTT_STATE_ARRAY_MAX_COUNT ] = { 0 }; + MQTTPubAckInfo_t outgoingRecords[ MQTT_STATE_ARRAY_MAX_COUNT ] = { 0 }; + + status = MQTT_Init( &mqttContext, &transport, + getTime, eventCallback, &networkBuffer ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + + status = MQTT_InitStatefulQoS( &mqttContext, + outgoingRecords, MQTT_STATE_ARRAY_MAX_COUNT, + incomingRecords, MQTT_STATE_ARRAY_MAX_COUNT ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + + /* Test for bad parameters */ + status = MQTT_GetFailedPacketId( NULL, &retrievedPacketId ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + status = MQTT_GetFailedPacketId( &mqttContext, NULL ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + + /* No record found. */ + status = MQTT_GetFailedPacketId( &mqttContext, &retrievedPacketId ); + TEST_ASSERT_EQUAL( MQTTBadParameter, status ); + + /* Success. */ + addToRecord( mqttContext.outgoingPublishRecords, 0, PACKET_ID, qos, publishState ); + status = MQTT_GetFailedPacketId( &mqttContext, &retrievedPacketId ); + TEST_ASSERT_EQUAL( MQTTSuccess, status ); + /* Verify the packet ID is retrieved correctly. */ + TEST_ASSERT_EQUAL( PACKET_ID, retrievedPacketId ); +}