Skip to content

Commit

Permalink
dump strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
yffbit committed Apr 25, 2024
1 parent 34bea81 commit ab7a22f
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 3 deletions.
4 changes: 3 additions & 1 deletion include/solver/slice_cfr.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class SliceCFR : public Solver {
int *same_hand_ptr[N_PLAYER] {nullptr,nullptr};
vector<vector<PrivateCards>> ranges;
vector<DFSNode> dfs_node;
vector<int> dfs_idx_map;// dfs遍历的每个节点在cuda中的索引
vector<int> dfs_idx_map;// dfs遍历的每个节点在内存中的索引
int node_cnt[N_TYPE];
int n_leaf_node = 0;
int n_player_node = 0;
Expand Down Expand Up @@ -176,6 +176,8 @@ class SliceCFR : public Solver {
void _rm(int player, bool best_cfv=false);
void clear_data(int player);
void clear_root_cfv();
json reConvertJson(const shared_ptr<GameTreeNode>& node, int depth, int max_depth, int &idx, int info);
vector<vector<float>> get_avg_strategy(int idx);
};

#endif // _SLICE_CFR_H_
87 changes: 85 additions & 2 deletions src/solver/slice_cfr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,10 @@ vector<float> SliceCFR::exploitability() {
// Requests termination by raising stop_flag; this function only sets the flag
// and does not wait for anything.
// NOTE(review): presumably the solving loop polls stop_flag — confirm at the reader side.
void SliceCFR::stop() {
stop_flag = true;
}
json SliceCFR::dumps(bool with_status, int depth) {
json ans;
// Serializes the game tree, starting from tree->getRoot(), into JSON.
//
// with_status: currently unused by this implementation.
// depth:       forwarded to reConvertJson as max_depth (max_round in the
//              original note); limits how deep nodes are emitted.
//
// Throws runtime_error("dfs idx error") when the number of node indices
// consumed by the traversal differs from dfs_idx.
json SliceCFR::dumps(bool with_status, int depth) {
    int idx = 0;
    json ans = reConvertJson(tree->getRoot(), 0, depth, idx, 0);
    if(idx != dfs_idx) throw runtime_error("dfs idx error");
    // Return the local by name: wrapping it in std::move would disable copy elision.
    return ans;
}
vector<vector<vector<float>>> SliceCFR::get_strategy(shared_ptr<ActionNode> node, vector<Card> cards) {
Expand All @@ -871,3 +873,84 @@ vector<vector<vector<float>>> SliceCFR::get_strategy(shared_ptr<ActionNode> node
// Stub: both arguments are ignored and an empty container is always returned.
vector<vector<vector<float>>> SliceCFR::get_evs(shared_ptr<ActionNode> node, vector<Card> cards) {
return {};
}
// Builds the normalized average strategy of the player node at DFS index `idx`.
//
// The accumulated strategy sums are read from node.data starting at offset
// 2 * n_act * n_hand and are indexed as [action * n_hand + hand].
// Returns a [n_hand][n_act] table; a hand whose accumulated sum is exactly 0
// gets the uniform value 1 / n_act for every action.
vector<vector<float>> SliceCFR::get_avg_strategy(int idx) {
    Node &player_entry = player_node[dfs_idx_map[idx]];
    int hand_count = hand_size[dfs_node[idx].player];
    int action_count = player_entry.n_act;
    int total = action_count * hand_count;
    float *cumulative = player_entry.data + (total << 1);
    float uniform_prob = 1.0 / action_count;

    vector<vector<float>> result(hand_count, vector<float>(action_count));// [n_hand,n_act]
    for(int hand = 0; hand < hand_count; hand++) {
        vector<float> &row = result[hand];

        // Accumulate over actions in increasing action order (stride = hand_count).
        float norm = 0;
        for(int pos = hand; pos < total; pos += hand_count) norm += cumulative[pos];

        if(norm != 0) {
            for(int act = 0; act < action_count; act++) {
                row[act] = cumulative[hand + act * hand_count] / norm;
            }
        }
        else {
            for(int act = 0; act < action_count; act++) row[act] = uniform_prob;
        }
    }
    return result;
}
// Recursively converts the subtree rooted at `node` into JSON.
//
// node:      current tree node.
// depth:     number of chance nodes entered so far on this path (it is only
//            incremented in the CHANCE branch).
// max_depth: nodes are emitted only while depth is within this limit; deeper
//            nodes are still traversed so that `idx` keeps advancing.
// idx:       in/out running node counter; each visited node takes the current
//            value (used as the index passed to get_avg_strategy) and
//            increments it.
// info:      encoded index of the previously dealt card, read via decode_idx0.
//
// Returns the JSON for this node; it is left as a default-constructed json for
// node types other than ACTION/CHANCE and for nodes beyond the depth limit.
json SliceCFR::reConvertJson(const shared_ptr<GameTreeNode>& node, int depth, int max_depth, int &idx, int info) {
    int curr_idx = idx++;
    int type = node->getType(), n_act = 0;
    json ans;
    if(type == GameTreeNode::ACTION) {
        shared_ptr<ActionNode> one_node = dynamic_pointer_cast<ActionNode>(node);
        vector<string> actions_str;
        if(depth < max_depth) {
            int player = one_node->getPlayer();
            for(GameActions one_action : one_node->getActions()) actions_str.push_back(one_action.toString());
            ans["actions"] = actions_str;
            ans["player"] = player;
            ans["node_type"] = "action_node";

            vector<vector<float>> strategy = get_avg_strategy(curr_idx);
            ans["strategy"] = json();
            ans["strategy"]["actions"] = actions_str;
            json stt;
            size_t n_hand = hand_size[player];
            int *ptr = hand_card_ptr[player];
            for(size_t i = 0; i < n_hand; i++) {
                // The two cards of hand i are stored at ptr[i + n_hand] and ptr[i].
                // The per-hand row is not needed afterwards, so move it into the JSON.
                stt[Card::intCard2Str(ptr[i+n_hand])+Card::intCard2Str(ptr[i])] = std::move(strategy[i]);
            }
            ans["strategy"]["strategy"] = std::move(stt);

            ans["childrens"] = json();
        }
        // Bind by const reference to avoid copying the vector of shared_ptr.
        const vector<shared_ptr<GameTreeNode>> &children = one_node->getChildrens();
        n_act = children.size();
        for(int i = 0; i < n_act; i++) {
            // Children are always visited, even past max_depth, so idx stays in step.
            json child = reConvertJson(children[i], depth, max_depth, idx, info);
            // Move the finished subtree into the parent instead of deep-copying it.
            if(depth < max_depth) ans["childrens"][actions_str[i]] = std::move(child);
        }
    }
    else if(type == GameTreeNode::CHANCE) {
        if((++depth) <= max_depth) ans["node_type"] = "chance_node";
        shared_ptr<ChanceNode> chance_node = dynamic_pointer_cast<ChanceNode>(node);
        shared_ptr<GameTreeNode> children = chance_node->getChildren();// never null
        int child_type = children->getType();
        n_act = chance_branch[GameTreeNode::gameRound2int(node->getRound())] + 4;
        if(child_type == GameTreeNode::ACTION || child_type == GameTreeNode::SHOWDOWN) {// one card needs to be dealt
            if(depth <= max_depth) ans["deal_number"] = n_act;
            if(depth < max_depth) ans["dealcards"] = json();// child nodes need to be expanded
            int j = decode_idx0(info), new_info = 0;
            for(int i = 0, k = 0; i < n_act; i++, k++) {// i: action index, k: poss_card index
                if(j == -1) new_info = code_idx0(k);// first deal
                else {// second deal; cards are dealt at most twice
                    if(k == j) k++;// same card as the first deal: advance to the next one
                    // new_info = code_idx0(max(j,k)) | code_idx1(min(j,k));// idx0 holds the larger value
                }
                json child = reConvertJson(children, depth, max_depth, idx, new_info);
                // Move the finished subtree into the parent instead of deep-copying it.
                if(depth < max_depth) ans["dealcards"][Card::intCard2Str(poss_card[k])] = std::move(child);
            }
        }
        else {
            // No recursion here: n_act*(n_act-1)/2 indices are skipped in one step.
            // NOTE(review): presumably the two-card runout case — confirm against the DFS that fills dfs_node.
            n_act = n_act*(n_act-1)>>1;
            idx += n_act;
            if(depth <= max_depth) ans["deal_number"] = n_act;
        }
    }
    // else {}
    // Return the local by name: wrapping it in std::move would disable copy elision.
    return ans;
}

0 comments on commit ab7a22f

Please sign in to comment.