-
Notifications
You must be signed in to change notification settings - Fork 0
/
利用MAB识别广告标识.py
46 lines (37 loc) · 1.06 KB
/
利用MAB识别广告标识.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gym_bandits
import gym
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Simulation size: one table row per round, one table column per banner.
# The reward table below is derived from these two constants so the table and
# the simulation loop can never disagree about their dimensions.
num_rounds = 100000
num_banner = 5

# Simulated reward table: every cell is an independent uniform draw from
# {0, 1} (presumably 1 = the visitor picked/clicked that banner -- confirm).
# Columns are created in order Banner_type_0 .. Banner_type_<num_banner - 1>.
df = pd.DataFrame()
for banner_idx in range(num_banner):
    df['Banner_type_{}'.format(banner_idx)] = np.random.randint(0, 2, num_rounds)
print(df.head(5))  # preview: what the guests choose in the first rounds

# Bandit bookkeeping; the three arrays are indexed by banner id.
banner_selected = []                # banner shown in each round, in order
count = np.zeros(num_banner)        # number of times each banner was shown
sum_rewards = np.zeros(num_banner)  # total reward collected by each banner
Q = np.zeros(num_banner)            # running mean reward of each banner
def epsilon_greedy(epsilon, q_values=None):
    """Pick a banner index with the epsilon-greedy rule.

    A uniform draw from [0, 1) is compared against ``epsilon``: when the draw
    is smaller, a banner is chosen uniformly at random (exploration);
    otherwise the banner with the highest value estimate is chosen
    (exploitation).

    Args:
        epsilon: Exploration probability. ``0`` always exploits; any value
            ``>= 1`` always explores.
        q_values: Optional sequence of per-banner value estimates. When
            omitted, the module-level ``Q`` and ``num_banner`` are used, so
            existing ``epsilon_greedy(eps)`` calls behave exactly as before.

    Returns:
        The index of the selected banner.
    """
    if q_values is None:
        estimates = Q
        n_arms = num_banner
    else:
        estimates = q_values
        n_arms = len(q_values)

    if np.random.random() < epsilon:
        return np.random.choice(n_arms)
    # np.argmax returns the first maximal index, so ties (e.g. the all-zero
    # estimates at the start) resolve to the lowest banner id.
    return np.argmax(estimates)
# Run the bandit: each round shows one banner, reads that banner's simulated
# reward for the round, and refreshes the banner's running statistics.
rewards = df.values  # extract once instead of re-deriving the array every round
for i in range(num_rounds):
    banner = epsilon_greedy(0.5)  # explore on roughly half of the rounds
    reward = rewards[i, banner]
    count[banner] += 1
    sum_rewards[banner] += reward
    # Value estimate = mean reward observed so far for this banner.
    Q[banner] = sum_rewards[banner] / count[banner]
    banner_selected.append(banner)

# How often each banner was shown. Banner ids are discrete, so draw one bar
# per id (histplot replaces the deprecated distplot, whose KDE overlay is not
# meaningful for categorical ids).
sns.histplot(banner_selected, discrete=True)
print(len(banner_selected))
print('the optimal is ',np.argmax(Q))
print(Q)
plt.show()
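# NOTE: this experiment may be meaningless: every banner's reward is drawn
# from the same uniform {0, 1} distribution, so no banner is truly better and
# the reported "optimal" banner is only sampling noise.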