-
Notifications
You must be signed in to change notification settings - Fork 0
/
fselection.py
46 lines (36 loc) · 1.67 KB
/
fselection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# Import necessary libraries
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import OneClassSVM
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
def select_features(df, target_column, exclude_column, num_features):
    """
    Select the features that matter most for predicting ``target_column``.

    A random forest classifier is fitted on every column of ``df`` except
    ``target_column`` and ``exclude_column``; the columns are then ranked by
    the forest's ``feature_importances_`` and the top ``num_features`` are
    kept. A horizontal bar chart of the kept importances (log-scaled x axis)
    is displayed with ``plt.show()`` as a side effect.

    Args:
        df: Input DataFrame holding the feature columns and the target column.
            It is NOT modified by this function.
        target_column: Name of the column to predict. A textual (object or
            string dtype) target is encoded as integer category codes.
        exclude_column: Name of an additional column to leave out of the
            feature set, or ``None`` when there is nothing extra to exclude.
        num_features: Number of top-ranked features to keep.

    Returns:
        A new DataFrame containing the selected feature columns, ordered by
        decreasing importance, followed by ``target_column`` (holding the
        encoded codes when the original target was textual).

    Raises:
        KeyError: If ``target_column`` or a non-``None`` ``exclude_column``
            is not a column of ``df``.
    """
    target = df[target_column]

    # Encode a textual target into integer codes on a local Series rather
    # than writing back into ``df``: the caller's frame must not change as a
    # hidden side effect. Checking for string dtypes as well as plain
    # ``object`` keeps the encoding working when pandas stores text in a
    # dedicated string dtype.
    is_textual = (
        pd.api.types.is_object_dtype(target.dtype)
        or pd.api.types.is_string_dtype(target.dtype)
    )
    if is_textual:
        target = target.astype('category').cat.codes

    # Everything except the target (and the optional excluded column) is a
    # candidate feature.
    dropped_columns = [target_column]
    if exclude_column is not None:
        dropped_columns.append(exclude_column)
    features = df.drop(dropped_columns, axis=1)

    # Fixed seed so the importance ranking is reproducible between runs.
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(features, target)

    # Rank features by importance and keep the strongest ``num_features``.
    features_importances = pd.DataFrame(
        {'Feature': features.columns, 'Importance': clf.feature_importances_}
    )
    features_importances = features_importances.sort_values(
        by='Importance', ascending=False
    ).head(num_features)

    # Plot the kept importances; the log scale keeps small importances
    # visible next to dominant ones.
    plt.figure(figsize=(10, 6))
    sns.barplot(x='Importance', y='Feature', data=features_importances)
    plt.xscale('log')
    if target_column == "Attack_label":
        plt.title('Feature Importances for Attack Detection')
    else:
        plt.title('Feature Importances for Attack Type Classification')
    plt.show()

    # Build the result as an independent frame: selected features first,
    # then the (possibly encoded) target as the last column.
    selected = df[features_importances['Feature'].tolist()].copy()
    selected[target_column] = target
    return selected