forked from rdspring1/Count-Sketch-Optimizers
cms.py
import torch
from cupy_kernel import cupyKernel
import numpy as np
import math
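# CUDA kernel implementing a count-min sketch update and query in one pass:
# each nonzero row index is hashed into three tables, the row's update is
# added to the matching cell in each table, and the minimum of the three
# updated cells is returned as the estimate.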
kernel = '''
extern "C"
__inline__ __device__
int hash(int value, int range, int a, int b)
{
    // MurmurHash3-style finalizer; unsigned arithmetic keeps the
    // bucket index non-negative before the final modulo.
    unsigned int h = a * value + b;
    h ^= h >> 16;
    h *= 0x85ebca6b;
    h ^= h >> 13;
    h *= 0xc2b2ae35;
    h ^= h >> 16;
    return h % range;
}

extern "C"
__inline__ __device__
float minimum(float a, float b, float c)
{
    return fminf(fminf(a, b), c);
}

extern "C"
__inline__ __device__
float update_retrieve(float* mem,
                      float* result,
                      const int N,
                      const int D,
                      const long index,
                      const float value)
{
    // Single-hash variant: add the update and return the new cell value.
    int a = 994443;
    int b = 609478;
    const int hash_idx = hash(index, N, a, b) * D + threadIdx.x;
    mem[hash_idx] += value;
    return mem[hash_idx];
}

extern "C"
__inline__ __device__
float cms_update_retrieve(float* mem,
                          float* result,
                          const int N,
                          const int W,
                          const int D,
                          const long index,
                          const float value)
{
    // Count-min sketch: add the update to one cell in each of the three
    // hash tables (spaced W floats apart) and return the minimum estimate.
    float r[3];
    int a[3] = {994443, 4113759, 9171025};
    int b[3] = {609478, 2949676, 2171464};
    for(int idx = 0; idx < 3; ++idx)
    {
        const int hash_idx = idx*W + hash(index, N, a[idx], b[idx]) * D + threadIdx.x;
        mem[hash_idx] += value;
        r[idx] = mem[hash_idx];
    }
    return minimum(r[0], r[1], r[2]);
}

extern "C"
__global__
void hash_update_retrieve(const long* indices,
                          const float* values,
                          float* mem,
                          float* result,
                          const int N,
                          const int W,
                          const int D)
{
    // One block per nonzero row, one thread per feature column.
    if(threadIdx.x < D)
    {
        const int idx = blockIdx.x * D + threadIdx.x;
        const float value = values[idx];
        const long index = indices[blockIdx.x];
        result[idx] = cms_update_retrieve(mem, result, N, W, D, index, value);
    }
}
'''
class CountMinSketch:
    def __init__(self, N, D, sketch_size=0.20):
        self.N = N
        self.D = D
        # Round D up to a multiple of 32 so the thread block covers every column.
        self.blk_size = math.ceil(D / 32) * 32
        # Each of the three hash tables gets a third of the sketch budget.
        self.range = int(N * sketch_size / 3.)
        self.width = self.range * D
        self.kernel = cupyKernel(kernel, "hash_update_retrieve")
        self.cms = torch.zeros(3, self.range, D).float().cuda()
        print(N, "CMS", self.cms.size())

    def update(self, indices, values, size):
        # Add each (index, row of values) pair into the sketch and return the
        # count-min estimates as a sparse tensor with the original shape.
        M, D = values.size()
        result = torch.zeros(values.size()).float().cuda()
        self.kernel(grid=(M, 1, 1),
                    block=(self.blk_size, 1, 1),
                    args=[indices.data_ptr(),
                          values.data_ptr(),
                          self.cms.data_ptr(),
                          result.data_ptr(),
                          self.range,
                          self.width,
                          self.D],
                    strm=torch.cuda.current_stream().cuda_stream)
        return torch.cuda.sparse.FloatTensor(indices, result, size)
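
# A minimal usage sketch, not part of the original file: the __main__ guard,
# the sizes, and the random inputs below are illustrative assumptions. It
# assumes a CUDA device and the cupy_kernel helper shipped with this repo.
if __name__ == "__main__":
    N, D = 10000, 64                # hypothetical number of rows / columns
    sketch = CountMinSketch(N, D)   # sketch_size defaults to 0.20

    # Five hypothetical active rows, each with a dense D-dimensional update.
    # Indices are shaped (1, nnz) so they also index the returned sparse tensor.
    indices = torch.randint(0, N, (1, 5), dtype=torch.int64).cuda()
    values = torch.randn(5, D).cuda()

    # Returns a sparse (N, D) tensor holding the count-min estimates for the
    # updated rows; repeated updates to the same row accumulate in the sketch.
    estimate = sketch.update(indices, values, torch.Size([N, D]))
    print(estimate)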