Skip to content

Commit 330d17f

Browse files
committed
updated emulator to handle p-type instructions
1 parent f9d71c5 commit 330d17f

15 files changed

Lines changed: 311 additions & 98 deletions

File tree

gpu/assembler/assembler.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
U_TYPE = {'auipc', 'lli', 'lmi', 'lui'}
2222
C_TYPE = {'csrr'}
2323
J_TYPE = {'jal'}
24-
P_TYPE = {'jpnz', 'prr', 'prw'}
24+
P_TYPE = {'jpnz', 'prsw', 'prlw'}
2525
H_TYPE = {'halt'}
2626

2727
# Instructions without predication
28-
NO_PREDICATE = {'halt', 'prw', 'prr', 'jpnz', 'jal', 'jalr'}
28+
NO_PREDICATE = {'halt', 'jal', 'jalr'}
2929

3030

3131
def load_opcodes(opcode_file: str) -> Dict[str, str]:
@@ -272,7 +272,7 @@ def encode_instruction(self, addr: int, opcode: str, operands: List[str]) -> str
272272
elif opcode in J_TYPE:
273273
required_ops = 2 # rd, imm/label
274274
elif opcode in P_TYPE:
275-
required_ops = 2 # rs1, rs2
275+
required_ops = 3 # prd, rs2, imm
276276
elif opcode in H_TYPE:
277277
required_ops = 0 # no operands
278278
else:
@@ -402,13 +402,33 @@ def encode_instruction(self, addr: int, opcode: str, operands: List[str]) -> str
402402
self.to_binary(imm, 17) + self.to_binary(rd, 6) + op_bits)
403403

404404
elif opcode in P_TYPE:
405-
# P-type: [end, start, pred, rs2[24:19], rs1[18:13], x[12:7], opcode[6:0]]
406-
# Note: No predication for these instructions
407-
rs1 = self.parse_register(operands[0])
408-
rs2 = self.parse_register(operands[1])
405+
# P-type: {end[31], start[30], p[29:25], rs2[24:19], imm[18:13], rd[12:7], opcode[6:0]}
406+
# prd, rs2, imm
407+
if(opcode == 'prsw'): # prsw prs, rs2, imm
408+
predicate = self.parse_predicate(operands[0])
409+
rs2 = self.parse_register(operands[1])
410+
imm = self.parse_immediate(operands[2])
411+
rd = 0
412+
elif(opcode == 'prlw'):
413+
rd = self.parse_predicate(operands[0])
414+
rs2 = self.parse_register(operands[1])
415+
imm = self.parse_immediate(operands[2])
416+
predicate = 0
417+
elif(opcode == 'jpnz'):
418+
predicate = self.parse_predicate(operands[0])
419+
if operands[1] in self.labels:
420+
target = self.labels[operands[1]]
421+
imm = target - addr # PC-relative offset
422+
else:
423+
imm = self.parse_immediate(operands[1])
424+
rs2 = 0
425+
rd = 0
426+
else:
427+
raise ValueError(f"P-type opcode {opcode} found but implemented")
428+
409429
return (self.to_binary(end, 1) + self.to_binary(start, 1) +
410-
'00000' + self.to_binary(rs2, 6) +
411-
self.to_binary(rs1, 6) + '000000' + op_bits)
430+
self.to_binary(predicate, 5) + self.to_binary(rs2, 6) +
431+
self.to_binary(imm, 6) + self.to_binary(rd, 6) + op_bits)
412432

413433
elif opcode in H_TYPE:
414434
# H-type (HALT): [end=1, start=0, 1s[29:7], opcode[6:0]]

gpu/assembler/opcodes.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ csrr 1011000
6666
jal 1100000
6767

6868
jpnz 1101000
69-
prr 1101100
70-
prw 1101101
69+
prsw 1101100
70+
prlw 1101101
7171

7272
halt 1111111

gpu/common/custom_enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ class J_Op(Op):
127127
# P-Type Operations (opcode: 1101xxx)
128128
class P_Op(Op):
129129
JPNZ = Bits(bin='000', length=3) # 000
130+
PRSW = Bits(bin='100', length=3) # 100
131+
PRLW = Bits(bin='101', length=3) # 101
132+
130133

131134
# H-Type Operations (opcode: 1111xxx)
132135
class H_Op(Op):

gpu/emulator/src/emulator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def parse_args():
7979
mem = Mem(args.start_pc, str(args.input_file), args.mem_format)
8080

8181
for block_id, warp_id in [(b, w) for b in range(args.num_blocks) for w in range(warps_per_block)]:
82-
pfile = PredicateRegFile(thread_per_warp=32)
82+
pfile = PredicateRegFile(threads_per_warp=32)
8383

8484
rfiles = []
8585
states = []

gpu/emulator/src/instr.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,6 @@ def decode(instruction: Bits, pc: Bits) -> 'Instr':
135135
op = J_Op(funct3)
136136
imm = pred + rs2 + rs1 #rs1 + rs2 + pred #concatenate
137137
ret_instr = J_Instr(op=op, rd=rd, imm=imm, pc=pc)
138-
case Instr_Type.P_TYPE:
139-
# TODO: Create P-Type Instruction Class
140-
raise NotImplementedError("P-Type instruction not implemented yet.")
141138
case Instr_Type.C_TYPE:
142139
op = C_Op(funct3)
143140
rs1 = rs1[0:5]
@@ -147,6 +144,10 @@ def decode(instruction: Bits, pc: Bits) -> 'Instr':
147144
op = F_Op(funct3)
148145
print(f"ftype, funct={op},imm={imm.int}")
149146
ret_instr = F_Instr(op=op, rs1=rs1, rd=rd)
147+
case Instr_Type.P_TYPE:
148+
op = P_Op(funct3)
149+
print(f"ptype, funct={op}, prd={rd}, rs2={rs2}, imm={rs1}")
150+
ret_instr = P_Instr(op, prd=rd, rs2=rs2, imm=rs1, pc=pc)
150151
case Instr_Type.H_TYPE:
151152
op=H_Op(funct3)
152153
print(f"halt, funct={op}, {funct3}")
@@ -480,7 +481,7 @@ def eval(self, csr: CsrRegFile, state: State) -> Optional[int]:
480481
imm_val = self.imm.int # Sign-extended immediate
481482

482483
# Calculate address
483-
addr = rdat1.int + imm_val
484+
addr = rdat1.uint + imm_val
484485
match self.op:
485486
# Memory Write Operations
486487
case S_Op_0.SW: # Store Word (32 bits / 4 bytes)
@@ -645,11 +646,38 @@ def eval(self, csr: CsrRegFile, state: State) -> Optional[int]:
645646
return target_addr & 0xFFFFFFFE # Ensure LSB is zero (word-aligned)
646647

647648
class P_Instr(Instr):
648-
def __init__(self, op: P_Op, rs1: Bits, rs2: Bits, pc: Bits, pred_reg_file: PredicateRegFile) -> None:
649-
raise NotImplementedError(f"P-Type operation {self.op} not implemented yet or doesn't exist.")
649+
def __init__(self, op: P_Op, prd: Bits, rs2: Bits, imm: Bits, pc: Bits) -> None:
650+
super().__init__(op)
651+
self.prd = prd[1:6]
652+
self.rs2 = rs2
653+
self.imm = imm
654+
self.pc = pc # Program counter for JPNZ
650655

651656
def eval(self, csr: CsrRegFile, state: State) -> bool:
652-
raise NotImplementedError(f"P-Type operation {self.op} not implemented yet or doesn't exist.")
657+
# Mark first thread in each warp
658+
is_first_thread = True
659+
if(csr.get_thread_id() % state.pfile.threads_per_warp):
660+
is_first_thread = False
661+
662+
addr = state.rfile.read(self.rs2).uint + self.imm.int
663+
664+
match self.op:
665+
# Jump Pred
666+
case P_Op.JPNZ:
667+
if(state.pfile.read(self.pred).uint != 0): # Has at least a predicate bit high
668+
return self.pc.int + self.imm.int
669+
else: # No predicate is set, let pass
670+
return None
671+
672+
# Predicate Access
673+
case P_Op.PRSW:
674+
if(not is_first_thread): return None
675+
state.memory.write(addr, state.pfile.read(self.pred), 4)
676+
case P_Op.PRLW:
677+
if(not is_first_thread): return None
678+
state.pfile.write(self.prd, state.memory.read(addr, 4))
679+
680+
return None
653681

654682
class H_Instr(Instr): #returns true
655683
def __init__(self, op: H_Op, funct3: Bits, r_pred: Bits = Bits(bin='11111', length=5)) -> None:

gpu/emulator/src/reg_file.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,32 @@ def write(self, rd: Bits, val: Bits) -> None:
2424

2525

2626
class PredicateRegFile(RegFile):
27-
def __init__(self, size: int=32, thread_per_warp: int=32) -> None:
28-
self.thread_per_warp = thread_per_warp
29-
super().__init__(num_regs=size, num_bits_per_reg=thread_per_warp, init_value=0)
30-
self.write(Bits(uint=0, length=5), Bits(uint=(1 << thread_per_warp) - 1, length=thread_per_warp)) # set p0 to all ones
27+
def __init__(self, size: int=32, threads_per_warp: int=32) -> None:
28+
self.threads_per_warp = threads_per_warp
29+
super().__init__(num_regs=size, num_bits_per_reg=threads_per_warp, init_value=0)
30+
self.write(Bits(uint=0, length=5), Bits(uint=(1 << threads_per_warp) - 1, length=threads_per_warp)) # set p0 to all ones
3131

3232
def read(self, rd: Bits) -> Bits:
3333
return self.arr[rd.uint]
3434

3535
def read_thread(self, rd: Bits, thread_id: int) -> bool:
36-
if(thread_id >= self.thread_per_warp):
36+
if(thread_id >= self.threads_per_warp):
3737
# TODO: Remove this behavior after memory system better defined for non linear thread system
38-
thread_id = thread_id % self.thread_per_warp
38+
thread_id = thread_id % self.threads_per_warp
3939

4040
reg_val = self.arr[rd.uint]
4141
return (reg_val.uint >> thread_id) & 0x1
4242

4343
def write_thread(self, rd: Bits, thread_id: int, val: bool) -> None:
44-
if(thread_id >= self.thread_per_warp):
44+
if(thread_id >= self.threads_per_warp):
4545
# TODO: Remove this behavior after memory system better defined for non linear thread system
46-
thread_id = thread_id % self.thread_per_warp
46+
thread_id = thread_id % self.threads_per_warp
4747

4848
reg_val = self.arr[rd.uint]
4949
if(val):
50-
self.arr[rd.uint] = Bits(uint=(reg_val.uint | (1 << thread_id)), length=self.thread_per_warp)
50+
self.arr[rd.uint] = Bits(uint=(reg_val.uint | (1 << thread_id)), length=self.threads_per_warp)
5151
else:
52-
self.arr[rd.uint] = Bits(uint=(reg_val.uint & ~(1 << thread_id)), length=self.thread_per_warp)
52+
self.arr[rd.uint] = Bits(uint=(reg_val.uint & ~(1 << thread_id)), length=self.threads_per_warp)
5353

5454
def write(self, rd: Bits, val: Bits) -> None:
5555
self.arr[rd.uint] = val

0 commit comments

Comments
 (0)