diff --git a/scripts/hostcfgd b/scripts/hostcfgd index 7f296f70..224b2b9d 100644 --- a/scripts/hostcfgd +++ b/scripts/hostcfgd @@ -1749,6 +1749,84 @@ class SerialConsoleCfg: return +class BannerCfg(object): + """ + Banner Config Daemon + Handles changes in BANNER_MESSAGE table. + 1) Handle change of feature state + 2) Handle change of login message + 3) Handle change of MOTD message + 4) Handle change of logout message + """ + + def __init__(self): + self.cache = {} + + def load(self, banner_messages_config: dict): + """Banner messages configuration + + Force load banner configuration. Login messages should be taken at boot-time by + SSH daemon. + + Args: + banners_message_config: Configured banner messages. + """ + + syslog.syslog(syslog.LOG_INFO, 'BannerCfg: load initial') + + if not banner_messages_config: + banner_messages_config = {} + + # Force load banner messages. + # Login messages show be taken at boot-time by SSH daemon. + state_data = banner_messages_config.get("state", {}) + login_data = banner_messages_config.get("login", {}) + motd_data = banner_messages_config.get("motd", {}) + logout_data = banner_messages_config.get("logout", {}) + + self.banner_message("state", state_data) + self.banner_message("login", login_data) + self.banner_message("motd", motd_data) + self.banner_message("logout", logout_data) + + def banner_message(self, key, data): + """ + Apply banner message handler. + + Args: + cache: Cache to compare/save data. + db: DB instance. + table: DB table that was changed. + key: DB table's key that was triggered change. + data: Read table data. + """ + # Handling state, login/logout and MOTD messages. Data should be a dict + if type(data) != dict: + # Nothing to handle + return + + update_required = False + # Check with cache + for k,v in data.items(): + if v != self.cache.get(k): + update_required = True + break + + if update_required == False: + return + + try: + run_cmd(["systemctl", "restart", "banner-config"], True, True) + except Exception: + syslog.syslog(syslog.LOG_ERR, 'BannerCfg: Failed to restart ' + 'banner-config service') + return + + # Update cache + for k,v in data.items(): + self.cache[k] = v + + class HostConfigDaemon: def __init__(self): self.state_db_conn = DBConnector(STATE_DB, 0) @@ -1803,6 +1881,9 @@ class HostConfigDaemon: # Initialize SerialConsoleCfg self.serialconscfg = SerialConsoleCfg() + # Initialize BannerCfg + self.bannermsgcfg = BannerCfg() + def load(self, init_data): aaa = init_data['AAA'] tacacs_global = init_data['TACPLUS'] @@ -1826,6 +1907,7 @@ class HostConfigDaemon: ntp_servers = init_data.get(swsscommon.CFG_NTP_SERVER_TABLE_NAME) ntp_keys = init_data.get(swsscommon.CFG_NTP_KEY_TABLE_NAME) serial_console = init_data.get('SERIAL_CONSOLE', {}) + banner_messages = init_data.get(swsscommon.CFG_BANNER_MESSAGE_TABLE_NAME) self.aaacfg.load(aaa, tacacs_global, tacacs_server, radius_global, radius_server, ldap_global, ldap_server) self.iptables.load(lpbk_table) @@ -1839,6 +1921,8 @@ class HostConfigDaemon: self.fipscfg.load(fips_cfg) self.ntpcfg.load(ntp_global, ntp_servers, ntp_keys) self.serialconscfg.load(serial_console) + self.bannermsgcfg.load(banner_messages) + self.pamLimitsCfg.update_config_file() # Update AAA with the hostname @@ -1992,6 +2076,10 @@ class HostConfigDaemon: syslog.syslog(syslog.LOG_INFO, 'SERIAL_CONSOLE table handler...') self.serialconscfg.update_serial_console_cfg(key, data) + def banner_handler(self, key, op, data): + syslog.syslog(syslog.LOG_INFO, 'BANNER_MESSAGE table handler...') + self.bannermsgcfg.banner_message(key, data) + def wait_till_system_init_done(self): # No need to print the output in the log file so using the "--quiet" # flag @@ -2059,6 +2147,10 @@ class HostConfigDaemon: self.config_db.subscribe(swsscommon.CFG_NTP_KEY_TABLE_NAME, make_callback(self.ntp_srv_key_handler)) + # Handle BANNER_MESSAGE changes + self.config_db.subscribe(swsscommon.CFG_BANNER_MESSAGE_TABLE_NAME, + make_callback(self.banner_handler)) + syslog.syslog(syslog.LOG_INFO, "Waiting for systemctl to finish initialization") self.wait_till_system_init_done() diff --git a/tests/hostcfgd/hostcfgd_test.py b/tests/hostcfgd/hostcfgd_test.py index 85f80625..9ec3f658 100644 --- a/tests/hostcfgd/hostcfgd_test.py +++ b/tests/hostcfgd/hostcfgd_test.py @@ -354,3 +354,20 @@ def test_load(self): data = {} dns_cfg.load(data) dns_cfg.dns_update.assert_called() + + +class TestBannerCfg: + def test_load(self): + banner_cfg = hostcfgd.BannerCfg() + banner_cfg.banner_message = mock.MagicMock() + + data = {} + banner_cfg.load(data) + banner_cfg.banner_message.assert_called() + + @mock.patch('hostcfgd.run_cmd') + def test_banner_message(self, mock_run_cmd): + banner_cfg = hostcfgd.BannerCfg() + banner_cfg.banner_message(None, {'test': 'test'}) + + mock_run_cmd.assert_has_calls([call(['systemctl', 'restart', 'banner-config'], True, True)])