
In the high-stakes world of medical diagnostics, AI models that confidently answer clinical questions need to be not just accurate, but also explainable. Historically, developing such cutting-edge AI has often meant relying on NVIDIA GPUs and their CUDA ecosystem, leaving other powerful hardware platforms as an afterthought. But what if you could achieve top-tier clinical AI performance, complete with robust reasoning, entirely on AMD hardware?
Enter MedQA, a groundbreaking project demonstrating that sophisticated medical question-answering models can be built and fine-tuned on AMD ROCm, without a single line of CUDA code. This initiative, developed for the AMD Developer Hackathon on lablab.ai, leverages the immense capabilities of the AMD Instinct MI300X GPU to train a LoRA-adapted Qwen3-1.7B model on the MedMCQA dataset. It’s a complete walkthrough, proving the viability and power of AMD for critical AI applications.
Revolutionizing Clinical AI with AMD ROCm
The core idea behind MedQA addresses a crucial gap: while medical question answering demands absolute precision, much of the open-source AI community defaults to NVIDIA infrastructure. A model that incorrectly answers a clinical MCQ isn’t just a misstep; it could have serious real-world implications. MedQA tackles this by providing both the correct answer and a comprehensive clinical explanation, all powered by AMD.
The star of the show is the AMD Instinct MI300X, an exceptional piece of hardware boasting an incredible 192 GB of HBM3 memory on a single device. This colossal VRAM capacity fundamentally changes the game for Large Language Model (LLM) fine-tuning, alleviating the most common constraint: memory limitations. With such ample memory, MedQA was trained in full fp16 precision, completely bypassing the need for complex 4-bit or 8-bit quantization hacks.
Crucially, this project showcases the seamless compatibility of the HuggingFace ecosystem—including Transformers, PEFT, TRL, and Accelerate—with ROCm. The same training code designed for CUDA runs perfectly on ROCm with just three simple environment variables configured. This means no custom kernels, no complex code changes, and no compatibility shims are required, streamlining development on AMD platforms.
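The post doesn't spell out which three variables those are, so the sketch below is an assumed, typical ROCm setup rather than the project's exact configuration:

```python
import os

# Assumed, commonly used ROCm settings, not necessarily the project's exact trio.
os.environ["HIP_VISIBLE_DEVICES"] = "0"  # select the MI300X (ROCm analog of CUDA_VISIBLE_DEVICES)
os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"  # allocator tuning (HIP analog of PYTORCH_CUDA_ALLOC_CONF)
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HuggingFace tokenizer fork warnings

import torch

# On ROCm builds, PyTorch exposes HIP devices through the familiar torch.cuda API,
# which is why unmodified CUDA training code runs as-is.
print(torch.cuda.is_available())       # True on a working ROCm install
print(torch.cuda.get_device_name(0))   # e.g. "AMD Instinct MI300X"
```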
Building MedQA: Dataset, Model, and LoRA Magic
For its foundation, MedQA utilizes the MedMCQA dataset, a large collection of multiple-choice questions drawn from Indian medical entrance exams (AIIMS and NEET PG), comparable in style to USMLE questions. To quickly demonstrate the fine-tuning process, a deliberately small subset of 2,000 training samples was used. This focused approach allowed for meaningful fine-tuning in minutes rather than hours.
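Loading that subset is a few lines with `datasets`; a minimal sketch, assuming the commonly used Hub ID `openlifescienceai/medmcqa` and a random 2,000-sample selection:

```python
from datasets import load_dataset

# MedMCQA from the HuggingFace Hub ("openlifescienceai/medmcqa" is the widely
# used dataset ID; the project's exact loading code is an assumption here).
dataset = load_dataset("openlifescienceai/medmcqa", split="train")

# Deliberately small demo subset: 2,000 training samples, as in the post.
small_train = dataset.shuffle(seed=42).select(range(2000))

# Each row carries the question, four options (opa..opd), the index of the
# correct option (cop), and a free-text explanation (exp).
print(small_train[0]["question"])
```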
The base model chosen for MedQA is Qwen/Qwen3-1.7B, Alibaba’s latest small-scale language model. At 1.7 billion parameters, it strikes an ideal balance, being compact enough for cost-effective fine-tuning while retaining sufficient capability to generate coherent and clinically relevant reasoning. Its compatibility with HuggingFace Transformers makes it an excellent choice for this project.
Consistency in prompt formatting is paramount for effective instruction fine-tuning, and MedQA employs a precise template for every training example and inference call. This structure ensures the model learns to treat the question, options, answer, and explanation as distinct parts of the context. During training, the model sees the full sequence, including the answer and explanation; during inference, it is given only the question and options and completes the answer and explanation itself.
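The project's exact wording isn't reproduced in this post, so the sketch below is an assumed layout that captures the described structure (question, options, answer, explanation):

```python
# Assumed template layout; the project's exact prompt wording may differ.
PROMPT_TEMPLATE = """### Question:
{question}

### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### Answer:
{answer}

### Explanation:
{explanation}"""

def format_example(row):
    """Render one MedMCQA row into training text.

    For inference, the same template is truncated right after "### Answer:"
    so the model completes the answer letter and the explanation itself.
    """
    answer_letter = "ABCD"[row["cop"]]  # cop is the 0-based index of the correct option
    return PROMPT_TEMPLATE.format(
        question=row["question"],
        opa=row["opa"], opb=row["opb"], opc=row["opc"], opd=row["opd"],
        answer=answer_letter,
        explanation=row["exp"] or "",
    )
```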
Instead of fine-tuning all 1.7 billion parameters, MedQA leverages LoRA (Low-Rank Adaptation) through the PEFT library. LoRA injects small, trainable rank-decomposition matrices into the attention layers while keeping the base model's weights frozen. This technique reduces the number of trainable parameters to approximately 2.2 million, roughly 0.14% of the model's total.
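In PEFT terms, that looks roughly like the following; the rank, alpha, and target modules here are illustrative assumptions, since the post only states the resulting trainable-parameter count:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the base model in full fp16; the MI300X's 192 GB of HBM3 makes
# 4-bit or 8-bit quantization unnecessary.
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", torch_dtype=torch.float16, device_map="auto"
)

# Illustrative LoRA hyperparameters (assumed, not the project's published values).
lora_config = LoraConfig(
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=16,                        # scaling applied to the LoRA update
    target_modules=["q_proj", "v_proj"],  # inject adapters into attention projections
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # reports trainable vs. total parameter counts
```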
This efficient parameter tuning, combined with the MI300X’s raw power, meant that training the model on the 2,000 MedMCQA samples took approximately 5 minutes. This incredibly fast turnaround demonstrates the potential for rapid iteration and development in medical AI on AMD hardware. The training arguments included essential settings like `per_device_train_batch_size=4`, `gradient_accumulation_steps=4` (leading to an effective batch size of 16), and `fp16=True` for optimal performance.
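Expressed as HuggingFace `TrainingArguments`, the named settings look like this; values the post doesn't mention (learning rate, output path, logging cadence) are filled in with assumed defaults:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="medqa-lora",         # assumed path
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # effective batch size of 4 * 4 = 16
    num_train_epochs=2,              # the post reports training for two epochs
    learning_rate=2e-4,              # assumed; a common LoRA fine-tuning value
    fp16=True,                       # full fp16 precision on the MI300X
    logging_steps=10,                # assumed
)
```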
From Training to Insightful Inference
The full training loop integrates standard HuggingFace components like `DataCollatorForSeq2Seq` and the `Trainer` class, making the process familiar to anyone in the ML community. After just two epochs, the LoRA adapter weights, which are only a few megabytes, are saved, ready to be deployed. This small file size dramatically simplifies storage and deployment compared to saving an entire multi-gigabyte model checkpoint.
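Wired together, the loop is a handful of lines; this sketch assumes a `tokenized_train` dataset produced by tokenizing the formatted prompts, plus the `model`, `tokenizer`, and `training_args` from the earlier snippets:

```python
from transformers import Trainer, DataCollatorForSeq2Seq

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,  # assumed: tokenized, templated examples
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model, padding=True),
)
trainer.train()

# Saving a PEFT model writes only the LoRA adapter weights, a few megabytes,
# rather than a multi-gigabyte full checkpoint.
model.save_pretrained("medqa-lora-adapter")
tokenizer.save_pretrained("medqa-lora-adapter")
```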
During inference, the process is straightforward: the base Qwen3-1.7B model is loaded, and the small LoRA adapter weights are attached. The model then operates in evaluation mode, ready to generate responses. The generation function employs greedy decoding with a repetition penalty to ensure coherent and non-looping outputs, focusing on clarity and precision.
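A minimal inference sketch, assuming the local adapter path from the training sketch above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", torch_dtype=torch.float16, device_map="auto"
)

# Attach the few-megabyte LoRA adapter and switch to evaluation mode.
model = PeftModel.from_pretrained(base, "medqa-lora-adapter")
model.eval()

prompt = "### Question:\nWhich of the following is the first-line treatment for hypertensive emergency?\n..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,         # greedy decoding
        repetition_penalty=1.2,  # discourages looping output; exact value assumed
    )

# Decode only the newly generated tokens, not the echoed prompt.
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```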
The results are truly impressive. When presented with a question like “Which of the following is the first-line treatment for hypertensive emergency?”, the model not only selects the correct option, B) IV labetalol or IV nitroprusside, but also provides a concise and clinically sound explanation. This ability to articulate reasoning is what elevates MedQA beyond simple answer recall, making it a genuinely useful tool in a clinical context.
For ease of access, the fine-tuned LoRA adapter is publicly available on the HuggingFace Hub. This means developers can load it directly, merging it with the base Qwen3-1.7B model to get a powerful clinical AI ready for immediate use. The entire setup is designed for maximum accessibility and deployment flexibility.
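Loading from the Hub follows the same pattern; the repository ID below is a placeholder, since the post doesn't state the adapter's actual Hub name:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", torch_dtype=torch.float16, device_map="auto"
)

# "<user>/medqa-lora" is a placeholder; substitute the adapter's real Hub ID.
model = PeftModel.from_pretrained(base, "<user>/medqa-lora")

# Optionally fold the LoRA deltas into the base weights, producing a plain
# Transformers model with no PEFT dependency at serving time.
merged = model.merge_and_unload()
merged.save_pretrained("medqa-merged")
```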
Overcoming Challenges and Looking Ahead
No cutting-edge project is without its hurdles, and MedQA encountered a few typical AMD ROCm quirks: `flash_attn` did not function as expected, and `bitsandbytes` was unavailable. These challenges inadvertently highlighted a key advantage of the MI300X: its enormous 192 GB of HBM3 memory made memory-saving techniques like 4-bit quantization unnecessary, simplifying the training pipeline and avoiding potential quantization artifacts.
The core results are compelling: training required only ~2.2 million trainable parameters (roughly 0.14% of the total model), took just ~5 minutes on the MI300X for 2,000 samples, and built upon a baseline MedMCQA accuracy of ~45%. The entire framework is built on PyTorch and ROCm 6.1, showcasing robust compatibility.
For those eager to experience MedQA, a live Gradio demo is available on HuggingFace Spaces for CPU inference. AMD hardware owners can easily clone the repository, install dependencies, and run `train.py` (taking only about 5 minutes) and `infer.py` locally. This hands-on opportunity demonstrates the ease of setup and the power of the platform.
This project unequivocally proves that a capable, explainable medical AI can be built on open-source AMD hardware. Future work involves scaling the dataset, hardening the model for production, and exploring advanced techniques like Retrieval-Augmented Generation (RAG) for even more robust, fact-grounded responses. The compatibility of the HuggingFace ecosystem with ROCm is remarkably strong, and the MI300X’s memory headroom truly removes significant engineering bottlenecks.
Source: Hugging Face Blog