From 2076857f73c0338a8dba74e17e914f26aba256e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20Mass=C3=A9?= Date: Mon, 2 Jun 2025 23:16:00 +0200 Subject: [PATCH] wip --- app/Containerfile | 12 +++++++++++- app/app.py | 1 + app/requirements.txt | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/app/Containerfile b/app/Containerfile index 365cc01..0679c9b 100644 --- a/app/Containerfile +++ b/app/Containerfile @@ -18,7 +18,17 @@ rm -rf /var/lib/apt/lists/* apt-get clean # Install Python dependencies -pip3 install onnxruntime_gpu +case "$(arch)" in +aarch64) + echo "Downloading onnxruntime_gpu from Nvidia" + wget -q https://nvidia.box.com/shared/static/i7n40ki3pl2x57vyn4u7e9asyiqlnl7n.whl -O onnxruntime_gpu-1.16.0-cp310-cp310-linux_aarch64.whl + pip3 install onnxruntime_gpu-1.16.0-cp310-cp310-linux_aarch64.whl + rm -f onnxruntime_gpu-1.16.0-cp310-cp310-linux_aarch64.whl + ;; +x86_64) + pip3 install onnxruntime_gpu + ;; +esac pip3 install -r requirements.txt EOF diff --git a/app/app.py b/app/app.py index f31e214..6d76644 100644 --- a/app/app.py +++ b/app/app.py @@ -82,6 +82,7 @@ if __name__ == "__main__": ) logger = logging.getLogger(__name__) ort_sess = ort.InferenceSession(MODEL_PATH, providers=PROVIDERS) + logger.info(f"ONNX Runtime device: {ort.get_device()}") nparr = np.fromfile(INPUT_IMAGE_PATH, np.uint8) nparr = cv2.imdecode(nparr, cv2.IMREAD_COLOR) preprocessed, scale, original_image = preprocess(nparr) diff --git a/app/requirements.txt b/app/requirements.txt index f9a892d..a2dc04a 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,3 +1,3 @@ opencv-python-headless numpy -onnxruntime +onnxruntime_gpu