Hello,
We are trying to build a graph from our PyTorch model exported to ONNX, but the following error message appears:
2023-11-03T03:17:29.002293Z INFO furiosa_rt_core::driver::event_driven::coord: [Runtime-0] created Sess-aae8a100 using npu:0:0, npu:1:0
2023-11-03T03:17:29.021616Z INFO furiosa_rt_core::driver::event_driven::coord: [Sess-aae8a100] compiling the model (target: warboy-b0, 64dpes, size: 12.2 KB)
[1/6] 🔍 Compiling from onnx to dfg
ERROR: verification of operator failed:
Reshape
name: /_mlp/_mlp.2/MatMul_output_0
input tensors: 1
input tensor 20: [32x1x1x1] NxHxWxC, 128 B, f32
source: unknown
total bytes: 128
output tensors: 1
output tensor 8: [1], 2048 B, f32
source: {/_mlp/_mlp.2/MatMul_output_0}
total bytes: 2048: invalid shape: LabeledShape { inner: LabeledShape { axis_size_map: AxisSizeMap { inner: {Batch: 32, Height: 1, Width: 1, Channel: 1} }, filter_hw_broadcasted: false, filter_diag_broadcasted: false } }, UnlabeledShape { inner: UnlabeledShape { sizes: [1] } }
2023-11-03T03:17:29.058639Z INFO furiosa_rt_core::driver::event_driven::coord: compilation failed, unloading npu:0:0, npu:1:0
================================================================================
Compilation Failure Report
================================================================================
- furiosa-runtime version: 0.10.2 (rev: a45bb1a0b built at 2023-10-12T06:41:21Z)
- furiosa-compiler version: 0.10.1 (rev: 8b00177dc built at 2023-10-12T06:26:59Z)
- libhal version: 0.11.0 (rev: 43c901f built at 2023-04-19T14:04:55Z)
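The numbers in the log already point at the mismatch. An ONNX Reshape must preserve the total number of elements, but the failing node is asked to turn a 32-element tensor ([32x1x1x1], 128 B of f32) into a 1-element tensor ([1]). The short pure-Python check below recomputes this from the shapes copied out of the log. The helper names are hypothetical, and the check handles only concrete shapes (no -1 or 0 entries, which Reshape's shape input also allows).

```python
from math import prod

F32_BYTES = 4  # size of one float32 element

# Shapes copied from the error log (NxHxWxC input, rank-1 output).
reshape_input_shape = (32, 1, 1, 1)
reshape_output_shape = (1,)


def element_count(shape):
    """Number of elements in a dense tensor of the given concrete shape."""
    return prod(shape)


def is_valid_reshape(src, dst):
    """ONNX Reshape must keep the total element count unchanged."""
    return element_count(src) == element_count(dst)


print("input elements:", element_count(reshape_input_shape))          # 32
print("input bytes:", element_count(reshape_input_shape) * F32_BYTES)  # 128, as in the log
print("output elements:", element_count(reshape_output_shape))        # 1
print("valid reshape?", is_valid_reshape(reshape_input_shape, reshape_output_shape))  # False
```

This only confirms why verification rejects the node. It does not show which part of the exported graph produced those two shapes.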
How can we fix this? Our code is below:
in_dim = model._nn_config.job_emb_dim + model._nn_config.comm_emb_dim * 3
in_sample = torch.randn(in_dim)
onnx_filepath = "./wysched/run/policy_network.onnx"
print("Converting Policy Network to ONNX model...")
torch.onnx.export(
    policy_network,
    (in_sample,),
    onnx_filepath,
    opset_version=13,
    do_constant_folding=True,
    input_names=["embedding"],
    output_names=["policy_out"],
)
print("onnx export Done.")
onnx_policy_network = InferenceSession(onnx_filepath, providers=["CPUExecutionProvider"])
onnx_model = onnx.load_model(onnx_filepath)
onnx_model = optimize_model(onnx_model)
if os.path.exists("./wysched/run/calibration_dataset.pkl"):
    with open("./wysched/run/calibration_dataset.pkl", "rb") as f:
        calibration_dataset = pickle.load(f)
else:
    calibration_dataset = torch.as_tensor(
        create_dataset(env, preprocessor, model, onnx_policy_network, orders, work_days),
        dtype=torch.float32,
    )
    with open("./wysched/run/calibration_dataset.pkl", "wb") as f:
        pickle.dump(calibration_dataset, f)
print('Calibration data collected with shape: ', calibration_dataset.shape)
if os.path.exists("./wysched/run/ranges.json"):
    # the ranges are stored with pickle, despite the .json extension
    with open("./wysched/run/ranges.json", "rb") as f:
        ranges = pickle.load(f)
else:
    # load dataset (one observation per batch)
    calibration_dataloader = torch.utils.data.DataLoader(
        calibration_dataset,
        batch_size=1,
    )
    calibrator = Calibrator(onnx_model, CalibrationMethod.MIN_MAX_ASYM)
    for calibration_data in tqdm.tqdm(
        calibration_dataloader,
        desc='Calibration',
        unit='observations',
        mininterval=0.5
    ):
        # drop the batch axis so the sample matches the rank-1 export input
        calibrator.collect_data([[calibration_data.squeeze().numpy()]])
    ranges = calibrator.compute_range()
    with open("./wysched/run/ranges.json", "wb") as f:
        pickle.dump(ranges, f)
graph = quantize(onnx_model, ranges)
with furiosa.runtime.session.create(graph) as session:
    # run env
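For comparison, the calibration loop relies on `.squeeze()` to turn each DataLoader item of shape (1, in_dim) back into the rank-1 shape of the export sample `torch.randn(in_dim)`. The NumPy sketch below walks through those shapes; `in_dim = 32` is a hypothetical value chosen only for illustration, and `torch.Tensor.squeeze` removes size-1 dimensions the same way NumPy does. It also shows why `squeeze(axis=0)` can be the safer call: `squeeze()` drops every size-1 axis, not just the batch axis.

```python
import numpy as np

in_dim = 32  # hypothetical value; the real one comes from model._nn_config

# One item from DataLoader(batch_size=1): a leading batch axis of size 1.
batch = np.random.randn(1, in_dim).astype(np.float32)

# What the posted loop feeds the calibrator: every size-1 axis removed.
squeezed_all = batch.squeeze()
# Removing only the batch axis gives the same shape here, but it never
# drops a feature axis that happens to have size 1.
squeezed_batch = batch.squeeze(axis=0)

print(squeezed_all.shape)    # (32,) -> same rank as torch.randn(in_dim) at export
print(squeezed_batch.shape)  # (32,)

# Edge case: with a single feature, squeeze() collapses to a 0-d array.
tiny = np.zeros((1, 1), dtype=np.float32)
print(tiny.squeeze().shape, tiny.squeeze(axis=0).shape)  # () (1,)
```

This only illustrates the shapes involved on the calibration side; it does not establish whether a different input rank at export time would change the compiler's behavior.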