Skip to content

Commit 1b75c01

Browse files
committed
feat(hpu): backend can now read/write HPU registers using mmapped BAR0 segment
1 parent f5cb6c1 commit 1b75c01

5 files changed

Lines changed: 146 additions & 67 deletions

File tree

backends/tfhe-hpu-backend/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ ipc-channel = "0.18.3"
5353
num-traits = { version = "0.2", optional = true }
5454
clap = { version = "4.4.4", features = ["derive"], optional = true }
5555
clap-num = { version = "1.1.1", optional = true }
56-
nix = { version = "0.29.0", features = ["ioctl", "uio", "fs"] }
56+
nix = { version = "0.29.0", features = ["mman", "ioctl", "uio", "fs"] }
5757

5858
# Dependencies used for rtl_graph features
5959
dot2 = { version = "1.0", optional = true }

backends/tfhe-hpu-backend/src/ffi/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ impl HpuHw {
228228
pub fn iop_ack_rd(&mut self) -> u32 {
229229
self.0.ami.iop_ackq_rd()
230230
}
231+
232+
/// Request the memory mapping of the HPU register window.
///
/// Thin forwarder: the call is handed to `map_bar_reg` on the `ami` driver held
/// by the inner hardware handle, and its result is returned unchanged.
///
/// # Errors
/// Propagates whatever error the `ami` driver reports.
#[cfg(feature = "hw-v80")]
pub fn map_bar_reg(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    let ami = &mut self.0.ami;
    ami.map_bar_reg()
}
231236
}
232237

233238
pub struct MemZone(

backends/tfhe-hpu-backend/src/ffi/v80/ami.rs

Lines changed: 99 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
//! AMI driver is used to issue gcq commands to the RPU
44
//! Those commands are used for configuration and register R/W
55
use lazy_static::lazy_static;
6-
use libc;
76
use std::error::Error;
87
use std::fs::{File, OpenOptions};
98
use std::io::{BufRead, BufReader, Read};
109
use std::os::fd::AsRawFd;
1110
use std::ptr::NonNull;
1211
use std::sync::atomic::{AtomicU32, Ordering};
1312
use std::time::Duration;
13+
use nix::sys::mman::{mmap, munmap, MapFlags, ProtFlags};
14+
use std::num::NonZero;
15+
use std::os::unix::fs::OpenOptionsExt;
1416

1517
const AMI_VERSION_FILE: &str = "/sys/module/ami/version";
1618
const AMI_VERSION_PATTERN: &str = r"3\.2\.\d+-zama";
@@ -81,6 +83,7 @@ impl AmiInfo {
8183

8284
pub struct AmiDriver {
8385
ami_dev: File,
86+
bar_reg_ptr: Option<NonNull<u8>>,
8487
iop_ack_atomic_ptr: NonNull<AtomicU32>,
8588
retry_rate: Duration,
8689
}
@@ -100,6 +103,7 @@ impl AmiDriver {
100103
.read(true)
101104
.write(true)
102105
.create(false)
106+
.custom_flags(libc::O_SYNC)
103107
.open(ami_path)?;
104108

105109
let ami_proc_path = format!("/proc/ami_iop_ack_{}", ami_info.devn);
@@ -109,43 +113,56 @@ impl AmiDriver {
109113
.create(false)
110114
.open(&ami_proc_path)
111115
.unwrap();
112-
unsafe {
113-
let addr = libc::mmap(
114-
std::ptr::null_mut(),
115-
4096,
116-
libc::PROT_READ | libc::PROT_WRITE,
117-
libc::MAP_SHARED,
118-
ami_proc.as_raw_fd(),
116+
117+
let addr = unsafe {
118+
mmap(
119+
None,
120+
NonZero::new(4096 as usize).unwrap(),
121+
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
122+
MapFlags::MAP_SHARED,
123+
&ami_proc,
119124
0,
120-
);
125+
)?
126+
};
121127

122-
if addr == libc::MAP_FAILED {
123-
return Err(format!("mmap on ami_iop_ack_{} failed", ami_info.devn).into());
124-
}
128+
let iop_ack_atomic_ptr: NonNull<AtomicU32> = addr.cast();
125129

126-
Ok(Self {
127-
ami_dev,
128-
iop_ack_atomic_ptr: NonNull::new_unchecked(addr as *mut AtomicU32),
129-
retry_rate,
130-
})
131-
}
130+
Ok(Self {
131+
ami_dev,
132+
bar_reg_ptr: None,
133+
iop_ack_atomic_ptr,
134+
retry_rate,
135+
})
136+
}
137+
138+
pub fn map_bar_reg(&mut self) -> Result<(), Box<dyn std::error::Error>> {
139+
let length: usize = 0x140000;
140+
141+
let map_addr = unsafe {
142+
mmap(
143+
None,
144+
NonZero::new(length).unwrap(),
145+
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE, // Read & Write
146+
MapFlags::MAP_SHARED,
147+
&self.ami_dev,
148+
0, // Offset in BAR0
149+
)?
150+
};
151+
tracing::info!("mapping HPU BAR0 at address -> {:p}", map_addr);
152+
153+
let bar_addr: NonNull<u8> = map_addr.cast();
154+
self.bar_reg_ptr = Some(bar_addr);
155+
156+
Ok(())
132157
}
133158

159+
134160
pub fn munmap_cnt(&self) -> Result<(), Box<dyn Error>> {
135-
let cnt_addr = self.iop_ack_atomic_ptr.as_ptr() as *mut libc::c_void;
161+
let cnt_addr = self.iop_ack_atomic_ptr.cast();
136162
unsafe {
137-
let retc = libc::munmap(cnt_addr, 4096);
138-
139-
if retc == 0 {
140-
Ok(())
141-
} else {
142-
Err(format!(
143-
"Could not unmap the atomic shared cnt (mumap returned {})",
144-
retc
145-
)
146-
.into())
147-
}
163+
munmap(cnt_addr, 4096)?;
148164
}
165+
Ok(())
149166
}
150167

151168
/// Read currently loaded UUID in BAR
@@ -276,24 +293,32 @@ impl AmiDriver {
276293
let data = Box::<u32>::new(0xdeadc0de);
277294
let data_ptr = Box::into_raw(data);
278295

279-
// Populate payload
280-
let payload = AmiPeakPokePayload {
281-
data_ptr,
282-
len: 0x1,
283-
offset: addr as u32,
284-
};
285-
286-
tracing::trace!("AMI: Read request with following payload {payload:x?}");
287-
loop {
288-
let ret = unsafe { ami_peak(ami_fd, &payload) };
289-
match ret {
290-
Err(err) => {
291-
tracing::debug!("AMI: Read failed -> {err:?}");
292-
std::thread::sleep(self.retry_rate);
293-
}
294-
Ok(val) => {
295-
tracing::trace!("AMI: Read ack received {payload:x?} -> {val:?}");
296-
break;
296+
if let Some(base) = self.bar_reg_ptr {
297+
unsafe {
298+
let raw_base = base.as_ptr();
299+
let reg_ptr = raw_base.add((addr + 0x100000).try_into().unwrap()) as *const u32;
300+
*data_ptr = std::ptr::read_volatile(reg_ptr);
301+
}
302+
} else {
303+
// Populate payload
304+
let payload = AmiPeakPokePayload {
305+
data_ptr,
306+
len: 0x1,
307+
offset: addr as u32,
308+
};
309+
310+
tracing::trace!("AMI: Read request with following payload {payload:x?}");
311+
loop {
312+
let ret = unsafe { ami_peak(ami_fd, &payload) };
313+
match ret {
314+
Err(err) => {
315+
tracing::debug!("AMI: Read failed -> {err:?}");
316+
std::thread::sleep(self.retry_rate);
317+
}
318+
Ok(val) => {
319+
tracing::trace!("AMI: Read ack received {payload:x?} -> {val:?}");
320+
break;
321+
}
297322
}
298323
}
299324
}
@@ -307,24 +332,32 @@ impl AmiDriver {
307332
let data = Box::<u32>::new(value);
308333
let data_ptr = Box::into_raw(data);
309334

310-
// Populate payload
311-
let payload = AmiPeakPokePayload {
312-
data_ptr,
313-
len: 0x1,
314-
offset: addr as u32,
315-
};
316-
317-
tracing::trace!("AMI: Write request with following payload {payload:x?}");
318-
loop {
319-
let ret = unsafe { ami_poke(ami_fd, &payload) };
320-
match ret {
321-
Err(err) => {
322-
tracing::debug!("AMI: Write failed -> {err:?}");
323-
std::thread::sleep(self.retry_rate);
324-
}
325-
Ok(val) => {
326-
tracing::trace!("AMI: Write ack received {payload:x?} -> {val:?}");
327-
break;
335+
if let Some(base) = self.bar_reg_ptr {
336+
unsafe {
337+
let raw_base = base.as_ptr();
338+
let reg_ptr = raw_base.add((addr + 0x100000).try_into().unwrap()) as *mut u32;
339+
std::ptr::write_volatile(reg_ptr, value);
340+
}
341+
} else {
342+
// Populate payload
343+
let payload = AmiPeakPokePayload {
344+
data_ptr,
345+
len: 0x1,
346+
offset: addr as u32,
347+
};
348+
349+
tracing::trace!("AMI: Write request with following payload {payload:x?}");
350+
loop {
351+
let ret = unsafe { ami_poke(ami_fd, &payload) };
352+
match ret {
353+
Err(err) => {
354+
tracing::debug!("AMI: Write failed -> {err:?}");
355+
std::thread::sleep(self.retry_rate);
356+
}
357+
Ok(val) => {
358+
tracing::trace!("AMI: Write ack received {payload:x?} -> {val:?}");
359+
break;
360+
}
328361
}
329362
}
330363
}

backends/tfhe-hpu-backend/src/interface/backend.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ pub struct HpuBackend {
2424

2525
// Extracted parameters
2626
pub(crate) params: HpuParameters,
27+
#[cfg(feature = "hw-v80")]
28+
hpu_version_major: u32,
29+
#[cfg(feature = "hw-v80")]
30+
hpu_version_minor: u32,
2731
// Avoid parsing the regmap at each polling iteration
2832
#[cfg(not(feature = "hw-v80"))]
2933
workq_addr: u64,
@@ -109,6 +113,18 @@ impl HpuBackend {
109113
let regmap = hw_regmap::FlatRegmap::from_file(&regmap_str);
110114
let mut params = HpuParameters::from_rtl(&mut hpu_hw, &regmap);
111115

116+
#[cfg(feature = "hw-v80")]
117+
let (hpu_version_major, hpu_version_minor) = {
118+
let version_reg = regmap
119+
.register()
120+
.get("info::version")
121+
.expect("Unknown register, check regmap definition");
122+
let hpu_version_val = hpu_hw.read_reg(*version_reg.offset() as u64);
123+
let hpu_version_fields = version_reg.as_field(hpu_version_val);
124+
(*hpu_version_fields.get("major").expect("Unknown field"),
125+
*hpu_version_fields.get("minor").expect("Unknown field"))
126+
};
127+
112128
// In case this is not filled by from_rtl()
113129
if params.ntt_params.min_pbs_nb.is_none() {
114130
params.ntt_params.min_pbs_nb = Some(config.firmware.min_batch_size);
@@ -282,6 +298,10 @@ impl HpuBackend {
282298
hpu_hw,
283299
regmap,
284300
params,
301+
#[cfg(feature = "hw-v80")]
302+
hpu_version_major,
303+
#[cfg(feature = "hw-v80")]
304+
hpu_version_minor,
285305
#[cfg(not(feature = "hw-v80"))]
286306
workq_addr,
287307
#[cfg(not(feature = "hw-v80"))]
@@ -916,6 +936,16 @@ impl HpuBackend {
916936
while self.poll_ack_q()? {}
917937
Ok(())
918938
}
939+
940+
/// Return the HPU version held by this backend as a `(major, minor)` pair.
///
/// Both values are plain copies of the `hpu_version_major` and
/// `hpu_version_minor` fields; no hardware access happens here.
#[cfg(feature = "hw-v80")]
pub(crate) fn get_hpu_version(&self) -> (u32, u32) {
    let Self {
        hpu_version_major,
        hpu_version_minor,
        ..
    } = self;
    (*hpu_version_major, *hpu_version_minor)
}
944+
945+
/// Ask the underlying hardware handle to map the HPU register window.
///
/// Forwards to `map_bar_reg` on the `hpu_hw` field and returns its result
/// unchanged.
///
/// # Errors
/// Propagates the error reported by the hardware handle.
#[cfg(feature = "hw-v80")]
pub(crate) fn map_bar_reg(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    let hw = &mut self.hpu_hw;
    hw.map_bar_reg()
}
919949
}
920950

921951
impl Drop for HpuBackend {

backends/tfhe-hpu-backend/src/interface/device.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,17 @@ impl HpuDevice {
9494
) where
9595
F: Fn(HpuParameters, &crate::asm::Pbs) -> HpuGlweLookuptableOwned<u64>,
9696
{
97+
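// Log the HPU version and map the BAR0 register segment when the version check passes.
// NOTE(review): `major >= 2 && minor >= 3` is false for e.g. 3.0; a tuple comparison
// `(major, minor) >= (2, 3)` is presumably what is meant — confirm.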
98+
{
99+
let mut backend = self.backend.lock().unwrap();
100+
let (major, minor) = backend.get_hpu_version();
101+
tracing::info!("HPU version -> {}.{}", major, minor);
102+
103+
if major >= 2 && minor >= 3 {
104+
backend.map_bar_reg().unwrap();
105+
}
106+
}
107+
97108
// Properly reset keys
98109
self.bsk_unset();
99110
self.ksk_unset();

0 commit comments

Comments
 (0)