Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Changes to support BF16 #77

Open
wants to merge 11 commits into
base: dev
Choose a base branch
from
112 changes: 81 additions & 31 deletions riscv_isac/InstructionObject.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import struct

instrs_sig_mutable = ['auipc','jal','jalr']
instrs_sig_update = ['sh','sb','sw','sd','c.sw','c.sd','c.swsp','c.sdsp','fsw','fsd',\
instrs_sig_update = ['sh','sb','sw','sd','c.fsw','c.sw','c.sd','c.swsp','c.sdsp','fsw','fsd',\
'c.fsw','c.fsd','c.fswsp','c.fsdsp']
instrs_no_reg_tracking = ['beq','bne','blt','bge','bltu','bgeu','fence','c.j','c.jal','c.jalr',\
'c.jr','c.beqz','c.bnez', 'c.ebreak'] + instrs_sig_update
Expand All @@ -12,9 +12,9 @@
'fmul.d','fdiv.d','fsqrt.d','fmin.d','fmax.d','fcvt.s.d','fcvt.d.s',\
'feq.d','flt.d','fle.d','fcvt.w.d','fcvt.wu.d','fcvt.l.d','fcvt.lu.d',\
'fcvt.d.l','fcvt.d.lu']
unsgn_rs1 = ['sw','sd','sh','sb','ld','lw','lwu','lh','lhu','lb', 'lbu','flw','fld','fsw','fsd',\
'bgeu', 'bltu', 'sltiu', 'sltu','c.lw','c.ld','c.lwsp','c.ldsp',\
'c.sw','c.sd','c.swsp','c.sdsp','mulhu','divu','remu','divuw',\
unsgn_rs1 = ['sw','sd','sh','sb','ld','lw','lwu','lh','lhu','lb', 'lbu','flw','fld','fsw','fsd','flh','fsh',\
'bgeu', 'bltu', 'sltiu', 'sltu','c.lw','c.lhu','c.lh','c.ld','c.lwsp','c.ldsp',\
'c.sw','c.sd','c.swsp','c.sdsp','c.fsw','mulhu','divu','remu','divuw',\
'remuw','aes64ds','aes64dsm','aes64es','aes64esm','aes64ks2',\
'sha256sum0','sha256sum1','sha256sig0','sha256sig1','sha512sig0',\
'sha512sum1r','sha512sum0r','sha512sig1l','sha512sig0l','sha512sig1h','sha512sig0h',\
Expand All @@ -23,7 +23,7 @@
'andn','orn','xnor','pack','packh','packu','packuw','packw',\
'xperm.n','xperm.b','grevi','aes64ks1i', 'shfli', 'unshfli', \
'aes32esmi', 'aes32esi', 'aes32dsmi', 'aes32dsi','bclr','bext','binv',\
'bset','zext.h','sext.h','sext.b','minu','maxu','orc.b','add.uw','sh1add.uw',\
'bset','zext.h','sext.h','sext.b','zext.b','zext.w','minu','maxu','orc.b','add.uw','sh1add.uw',\
'sh2add.uw','sh3add.uw','slli.uw','clz','clzw','ctz','ctzw','cpop','cpopw','rev8',\
'bclri','bexti','binvi','bseti','fcvt.d.wu','fcvt.s.wu','fcvt.d.lu','fcvt.s.lu']
unsgn_rs2 = ['bgeu', 'bltu', 'sltiu', 'sltu', 'sll', 'srl', 'sra','mulhu',\
Expand Down Expand Up @@ -82,7 +82,10 @@ def __init__(
rm = None,
reg_commit = None,
csr_commit = None,
mnemonic = None
mnemonic = None,
inxFlag = None,
is_sgn_extd = None,
bf16 = None
):

'''
Expand Down Expand Up @@ -117,10 +120,12 @@ def __init__(
self.csr_commit = csr_commit
self.mnemonic = mnemonic
self.is_rvp = False
self.inxFlg = inxFlag
self.rs1_nregs = 1
self.rs2_nregs = 1
self.rs3_nregs = 1
self.rd_nregs = 1
self.bf16 = bf16


def is_sig_update(self):
Expand All @@ -146,6 +151,11 @@ def evaluate_instr_vars(self, xlen, flen, arch_state, csr_regfile, instr_vars):
instr_vars['iflen'] = 32
elif self.instr_name.endswith(".d"):
instr_vars['iflen'] = 64
elif self.instr_name.endswith(".h"):
instr_vars['iflen'] = 16
elif self.instr_name.endswith(".bf16"):
instr_vars['iflen'] = -16
instr_vars['bf16'] = True

# capture the operands
if self.rs1 is not None:
Expand Down Expand Up @@ -175,6 +185,8 @@ def evaluate_instr_vars(self, xlen, flen, arch_state, csr_regfile, instr_vars):
ea_align = (self.instr_addr+(imm_val<<1)) % 4
if self.instr_name == "jalr":
ea_align = (rs1_val + imm_val) % 4
if self.instr_name in ['fsh','flh']:
ea_align = (rs1_val + imm_val) % 2
if self.instr_name in ['sw','sh','sb','lw','lhu','lh','lb','lbu','lwu','flw','fsw']:
ea_align = (rs1_val + imm_val) % 4
if self.instr_name in ['ld','sd','fld','fsd']:
Expand Down Expand Up @@ -286,7 +298,11 @@ def update_arch_state(self, arch_state, csr_regfile):
if commitvalue is not None:
if self.rd[1] == 'x':
arch_state.x_rf[int(commitvalue[1])] = str(commitvalue[2][2:])
print(arch_state.x_rf[int(commitvalue[1])])

elif self.rd[1] == 'f':
#offset = len(commitvalue[2])-len(arch_state.f_rf[int(commitvalue[1])])
#arch_state.f_rf[int(commitvalue[1])] = str(commitvalue[2][offset:])
arch_state.f_rf[int(commitvalue[1])] = str(commitvalue[2][2:])

csr_commit = self.csr_commit
Expand All @@ -310,7 +326,9 @@ def evaluate_instr_var(self, instr_var_name, *args):
rs1 = self.rs1,
rs2 = self.rs2,
rs3 = self.rs3,
is_rvp = self.is_rvp
is_rvp = self.is_rvp,
inxFlag = self.inxFlg,
bf16 = self.bf16
): # could just instr_name suffice?
return func(self, *args)

Expand All @@ -333,14 +351,14 @@ def evaluate_rs1_val_p_ext(self, instr_vars, arch_state):
return self.evaluate_reg_val_p_ext(self.rs1[0], self.rs1_nregs, arch_state)


@evaluator_func("rs1_val", lambda **params: not params['instr_name'] in unsgn_rs1 and not params['is_rvp'] and params['rs1'] is not None and params['rs1'][1] == 'x')
@evaluator_func("rs1_val", lambda **params: not params['instr_name'] in unsgn_rs1 and not params['is_rvp'] and params['rs1'] is not None and params['rs1'][1] == 'x' and not params['inxFlag'])
def evaluate_rs1_val_sgn(self, instr_vars, arch_state):
return self.evaluate_reg_val_sgn(self.rs1[0], instr_vars['xlen'], arch_state)


@evaluator_func("rs1_val", lambda **params: not params['instr_name'] in unsgn_rs1 and not params['is_rvp'] and params['rs1'] is not None and params['rs1'][1] == 'f')
@evaluator_func("rs1_val", lambda **params: not params['instr_name'] in unsgn_rs1 and not params['is_rvp'] and params['rs1'] is not None and (params['rs1'][1] == 'f' or params['inxFlag']))
def evaluate_rs1_val_fsgn(self, instr_vars, arch_state):
return self.evaluate_reg_val_fsgn(self.rs1[0], instr_vars['flen'], arch_state)
return self.evaluate_reg_val_fsgn(self.rs1[0], instr_vars['flen'], instr_vars['xlen'],arch_state)


'''
Expand All @@ -359,14 +377,14 @@ def evaluate_rs2_val_p_ext(self, instr_vars, arch_state):
return self.evaluate_reg_val_p_ext(self.rs2[0], self.rs2_nregs, arch_state)


@evaluator_func("rs2_val", lambda **params: not params['instr_name'] in unsgn_rs2 and not params['is_rvp'] and params['rs2'] is not None and params['rs2'][1] == 'x')
@evaluator_func("rs2_val", lambda **params: not params['instr_name'] in unsgn_rs2 and not params['is_rvp'] and params['rs2'] is not None and params['rs2'][1] == 'x' and not params['inxFlag'])
def evaluate_rs2_val_sgn(self, instr_vars, arch_state):
return self.evaluate_reg_val_sgn(self.rs2[0], instr_vars['xlen'], arch_state)


@evaluator_func("rs2_val", lambda **params: not params['instr_name'] in unsgn_rs2 and not params['is_rvp'] and params['rs2'] is not None and params['rs2'][1] == 'f')
@evaluator_func("rs2_val", lambda **params: not params['instr_name'] in unsgn_rs2 and not params['is_rvp'] and params['rs2'] is not None and (params['rs2'][1] == 'f' or params['inxFlag']))
def evaluate_rs2_val_fsgn(self, instr_vars, arch_state):
return self.evaluate_reg_val_fsgn(self.rs2[0], instr_vars['flen'], arch_state)
return self.evaluate_reg_val_fsgn(self.rs2[0], instr_vars['flen'], instr_vars['xlen'], arch_state)


'''
Expand All @@ -375,9 +393,9 @@ def evaluate_rs2_val_fsgn(self, instr_vars, arch_state):
:param arch_state: Architectural state
:param instr_vars: Dictionary of instruction variables already evaluated
'''
@evaluator_func("rs3_val", lambda **params: params['rs3'] is not None and params['rs3'][1] == 'f')
@evaluator_func("rs3_val", lambda **params: params['rs3'] is not None and (params['rs3'][1] == 'f' or params['inxFlag']))
def evaluate_rs3_val_fsgn(self, instr_vars, arch_state):
return self.evaluate_reg_val_fsgn(self.rs3[0], instr_vars['flen'], arch_state)
return self.evaluate_reg_val_fsgn(self.rs3[0], instr_vars['flen'], instr_vars['xlen'], arch_state)


'''
Expand All @@ -393,12 +411,12 @@ def evaluate_f_ext_sem(self, instr_vars, arch_state, csr_regfile):

f_ext_vars['fcsr'] = int(csr_regfile['fcsr'], 16)

if 'rs1' in instr_vars and instr_vars['rs1'] is not None and instr_vars['rs1'].startswith('f'):
self.evaluate_reg_sem_f_ext(instr_vars['rs1_val'], instr_vars['flen'], instr_vars['iflen'], "1", f_ext_vars)
if 'rs2' in instr_vars and instr_vars['rs2'] is not None and instr_vars['rs2'].startswith('f'):
self.evaluate_reg_sem_f_ext(instr_vars['rs2_val'], instr_vars['flen'], instr_vars['iflen'], "2", f_ext_vars)
if 'rs3' in instr_vars and instr_vars['rs3'] is not None and instr_vars['rs3'].startswith('f'):
self.evaluate_reg_sem_f_ext(instr_vars['rs3_val'], instr_vars['flen'], instr_vars['iflen'], "3", f_ext_vars)
if 'rs1' in instr_vars and instr_vars['rs1'] is not None and (instr_vars['rs1'].startswith('f') or instr_vars['inxFlag']):
self.evaluate_reg_sem_f_ext(instr_vars['rs1_val'], instr_vars['flen'], instr_vars['iflen'], instr_vars['bf16'], "1", f_ext_vars, instr_vars['inxFlag'], instr_vars['xlen'])
if 'rs2' in instr_vars and instr_vars['rs2'] is not None and (instr_vars['rs2'].startswith('f') or instr_vars['inxFlag']):
self.evaluate_reg_sem_f_ext(instr_vars['rs2_val'], instr_vars['flen'], instr_vars['iflen'], instr_vars['bf16'], "2", f_ext_vars, instr_vars['inxFlag'], instr_vars['xlen'])
if 'rs3' in instr_vars and instr_vars['rs3'] is not None and (instr_vars['rs3'].startswith('f') or instr_vars['inxFlag']):
self.evaluate_reg_sem_f_ext(instr_vars['rs3_val'], instr_vars['flen'], instr_vars['iflen'], instr_vars['bf16'], "3", f_ext_vars, instr_vars['inxFlag'], instr_vars['xlen'])

return f_ext_vars

Expand All @@ -416,9 +434,12 @@ def evaluate_reg_val_sgn(self, reg_idx, xlen, arch_state):
return struct.unpack(sgn_sz, bytes.fromhex(arch_state.x_rf[reg_idx]))[0]


def evaluate_reg_val_fsgn(self, reg_idx, flen, arch_state):
fsgn_sz = '>Q' if flen == 64 else '>I'
return struct.unpack(fsgn_sz, bytes.fromhex(arch_state.f_rf[reg_idx]))[0]
def evaluate_reg_val_fsgn(self, reg_idx, flen, xlen, arch_state):
fsgn_sz = '>Q' if flen == 64 and xlen > 32 else '>I'
if self.inxFlg:
return struct.unpack(fsgn_sz, bytes.fromhex(arch_state.x_rf[reg_idx]))[0]
else:
return struct.unpack(fsgn_sz, bytes.fromhex(arch_state.f_rf[reg_idx]))[0]


def evaluate_reg_val_p_ext(self, reg_idx, nregs, arch_state):
Expand All @@ -427,27 +448,54 @@ def evaluate_reg_val_p_ext(self, reg_idx, nregs, arch_state):
reg_hi_val = evaluate_reg_val_unsgn(reg_idx+1, arch_state)
reg_val = (reg_hi_val << 32) | reg_val
return reg_val


def evaluate_reg_sem_f_ext(self, reg_val, flen, iflen, postfix, f_ext_vars):

def sign_extend(self, value, e_bits, v_bits ):
return bin(value | ((1<<e_bits) - (1<<v_bits)))

def twos_comp(val, bits):
"""compute the 2's complement of int value val"""
if (val & (1 << (bits - 1))) != 0:
val = val - (1 << bits) # compute negative value
return val # return positive value as is

def apndSgnBit(bin_val,sgn_bit):
new_bin = list(bin_val)
new_bin[0] = sgn_bit
final_bin = ''.join(new_bin)
return final_bin

def evaluate_reg_sem_f_ext(self, reg_val, flen, iflen, bf16, postfix, f_ext_vars, inxFlag, xlen):
'''
This function expands reg_val and defines the respective sign, exponent and mantissa components
'''
if reg_val is None:
return

if iflen == 32:
if iflen == 16 and not bf16:
e_sz = 5
m_sz = 10
elif iflen == 16:
e_sz = 8
m_sz = 7
elif iflen == 32:
e_sz = 8
m_sz = 23
else:
e_sz = 11
m_sz = 52
bin_val = ('{:0'+str(flen)+'b}').format(reg_val)

if flen > iflen:
f_ext_vars['rs'+postfix+'_nan_prefix'] = int(bin_val[0:flen-iflen],2)
if inxFlag:
if bin_val[32] == '1' :
sgnd_bin_val = bin(reg_val &((1<<flen)-1) | ((1<<flen) - (1<<iflen)))[2:]
f_ext_vars['rs'+postfix+'_sgn_prefix'] = int(sgnd_bin_val[0:iflen],2)
else:
f_ext_vars['rs'+postfix+'_sgn_prefix'] = int(0x0)
else:
bin_val =bin(reg_val &((1<<flen)-1) | ((1<<flen) - (1<<iflen)))[2:]
f_ext_vars['rs'+postfix+'_nan_prefix'] = int(bin_val[0:iflen],2)
bin_val = bin_val[flen-iflen:]

f_ext_vars['fs'+postfix] = int(bin_val[0], 2)
f_ext_vars['fe'+postfix] = int(bin_val[1:e_sz+1], 2)
f_ext_vars['fm'+postfix] = int(bin_val[e_sz+1:], 2)
Expand Down Expand Up @@ -487,4 +535,6 @@ def __str__(self):
line+= ' csr_commit: '+ str(self.csr_commit)
if self.mnemonic:
line+= ' mnemonic: '+ str(self.mnemonic)
if self.bf16:
line+= ' bf16: '+ str(self.bf16)
return line
25 changes: 19 additions & 6 deletions riscv_isac/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self,label,coverpoint,xlen,flen,addr_pairs,sig_addrs,window_size):
self.sig_addrs = sig_addrs
self.window_size = window_size

self.arch_state = archState(xlen,flen)
self.arch_state = archState(xlen,flen,inxFlg)
self.csr_regfile = csr_registers(xlen)
self.stats = statistics(xlen, flen)

Expand Down Expand Up @@ -510,7 +510,7 @@ class archState:
Defines the architectural state of the RISC-V device.
'''

def __init__ (self, xlen, flen):
def __init__ (self, xlen, flen,inxFlg):
'''
Class constructor

Expand All @@ -534,12 +534,17 @@ def __init__ (self, xlen, flen):
else:
self.x_rf = ['0000000000000000']*32

if flen == 32:
if flen == 16:
self.f_rf = ['0000']*32
self.fcsr = 0
elif flen == 32:
self.f_rf = ['00000000']*32

else:
self.f_rf = ['0000000000000000']*32
self.pc = 0
self.flen = flen
self.inxFlg = inxFlg

class statistics:
'''
Expand Down Expand Up @@ -788,6 +793,7 @@ def compute_per_line(queue, event, cgf_queue, stats_queue, cgf, xlen, flen, addr
:param sig_addrs: pairs of start and end addresses for which signature update needs to be checked
:param stats: `stats` object
:param csr_regfile: Architectural state of CSR register file
:param result_count:

:type queue: class`multiprocessing.Queue`
:type event: class`multiprocessing.Event`
Expand All @@ -801,11 +807,13 @@ def compute_per_line(queue, event, cgf_queue, stats_queue, cgf, xlen, flen, addr
:type sig_addrs: (int, int)
:type stats: class `statistics`
:type csr_regfile: class `csr_registers`
:type result_count: int
'''

# List to hold hit coverpoints
hit_covpts = []
rcgf = copy.deepcopy(cgf)
inxFlg = arch_state.inxFlg

# Set of elements to monitor for tracking signature updates
tracked_regs_immutable = set()
Expand Down Expand Up @@ -840,6 +848,8 @@ def compute_per_line(queue, event, cgf_queue, stats_queue, cgf, xlen, flen, addr
enable=True

instr_vars = {}
instr_vars['inxFlag'] = instr.inxFlg
instr_vars['bf16'] = instr.bf16
instr.evaluate_instr_vars(xlen, flen, arch_state, csr_regfile, instr_vars)

old_csr_regfile = {}
Expand Down Expand Up @@ -1287,13 +1297,14 @@ def write_fn_csr_comb_covpt(csr_reg):
stats_queue.close()

def compute(trace_file, test_name, cgf, parser_name, decoder_name, detailed, xlen, flen, addr_pairs
, dump, cov_labels, sig_addrs, window_size, no_count=False, procs=1):
, dump, cov_labels, sig_addrs, window_size, inxFlg, no_count=False, procs=1):
'''Compute the Coverage'''

global arch_state
global csr_regfile
global stats
global cross_cover_queue
global result_count

temp = cgf.copy()
if cov_labels:
Expand All @@ -1314,9 +1325,11 @@ def compute(trace_file, test_name, cgf, parser_name, decoder_name, detailed, xle
dump_f.close()
sys.exit(0)

arch_state = archState(xlen,flen)
arch_state = archState(xlen,flen,inxFlg)
csr_regfile = csr_registers(xlen)
stats = statistics(xlen, flen)
cross_cover_queue = []
result_count = 0

## Get coverpoints from cgf
obj_dict = {} ## (label,coverpoint): object
Expand Down Expand Up @@ -1353,7 +1366,7 @@ def compute(trace_file, test_name, cgf, parser_name, decoder_name, detailed, xle
decoderclass = getattr(instructionObjectfile, "disassembler")
decoder_pm.register(decoderclass())
decoder = decoder_pm.hook
decoder.setup(arch="rv"+str(xlen))
decoder.setup(inxFlag=inxFlg, arch="rv"+str(xlen))

iterator = iter(parser.__iter__()[0])

Expand Down
2 changes: 1 addition & 1 deletion riscv_isac/data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


isa_regex = \
re.compile("^RV(32|64|128)[IE]+[ABCDEFGHJKLMNPQSTUVX]*(Zicsr|Zifencei|Zihintpause|Zam|Ztso|Zkne|Zknd|Zknh|Zkse|Zksh|Zkg|Zkb|Zkr|Zks|Zkn|Zba|Zbc|Zbb|Zbp|Zbr|Zbm|Zbs|Zbe|Zbf|Zbt|Zmmul|Zbpbo){,1}(_Zicsr){,1}(_Zifencei){,1}(_Zihintpause){,1}(_Zmmul){,1}(_Zam){,1}(_Zba){,1}(_Zbb){,1}(_Zbc){,1}(_Zbe){,1}(_Zbf){,1}(_Zbm){,1}(_Zbp){,1}(_Zbpbo){,1}(_Zbr){,1}(_Zbs){,1}(_Zbt){,1}(_Zkb){,1}(_Zkg){,1}(_Zkr){,1}(_Zks){,1}(_Zkn){,1}(_Zknd){,1}(_Zkne){,1}(_Zknh){,1}(_Zkse){,1}(_Zksh){,1}(_Ztso){,1}$")
re.compile("^RV(32|64|128)[IE]+[ABCDEFGHJKLMNPQSTUVX]*(Zfinx|Zfh|Zicsr|Zifencei|Zihintpause|Zam|Ztso|Zkne|Zknd|Zknh|Zkse|Zksh|Zkg|Zkb|Zkr|Zks|Zkn|Zba|Zbc|Zbb|Zbp|Zbr|Zbm|Zbs|Zbe|Zbf|Zbt|Zmmul|Zbpbo){,1}(_Zicsr){,1}(_Zifencei){,1}(_Zihintpause){,1}(_Zfinx){,1}(_Zfh){,1}(_Zmmul){,1}(_Zam){,1}(_Zba){,1}(_Zbb){,1}(_Zbc){,1}(_Zbe){,1}(_Zbf){,1}(_Zbm){,1}(_Zbp){,1}(_Zbpbo){,1}(_Zbr){,1}(_Zbs){,1}(_Zbt){,1}(_Zkb){,1}(_Zkg){,1}(_Zkr){,1}(_Zks){,1}(_Zkn){,1}(_Zknd){,1}(_Zkne){,1}(_Zknh){,1}(_Zkse){,1}(_Zksh){,1}(_Ztso){,1}$")

# regex to find <msb>..<lsb>=<val> patterns in instruction
fixed_ranges = re.compile(
Expand Down
Loading