Un modèle d’IA médicale affiné sans CUDA tourne sur AMD ROCm. Il répond à des QCM cliniques et explique son raisonnement.

L’IDÉE

La réponse automatique aux questions médicales est un domaine où les enjeux sont très élevés. Un modèle qui choisit avec assurance la mauvaise réponse à un QCM clinique n’est pas seulement faux, il est dangereux. Pourtant, la plupart des travaux en IA médicale open source supposent que l’on possède un GPU NVIDIA. CUDA est la norme. Tout le reste est secondaire.

MedQA est un modèle de réponse aux questions cliniques affiné avec LoRA et entièrement construit sur du Matériel AMD grâce à ROCm. Il prend une question médicale à choix multiple et renvoie à la fois la bonne réponse et une explication clinique du raisonnement. L’ensemble du pipeline d’entraînement, du chargement des données à l’exportation de l’adaptateur, s’exécute sur une AMD Instinct MI300X sans la moindre dépendance à CUDA.

Pourquoi AMD ROCm ? L’AMD Instinct MI300X est une pièce de matériel remarquable : 192 Go de mémoire HBM3 dans un seul appareil. Pour l’affinage de grands modèles de langage, la mémoire vidéo est souvent le facteur limitant : elle dicte la taille du lot, la longueur de séquence et la nécessité de quantification. Avec 192 Go disponibles, le modèle Qwen3-1.7B a été entraîné avec LoRA en fp16 complet, sans aucune quantification 4 bits ou 8 bits.

Plus important encore, l’objectif était de prouver que l’écosystème HuggingFace (Transformers, PEFT, TRL, Accelerate) fonctionne parfaitement sur ROCm. Et c’est le cas. Le même code d’entraînement qui tourne sur CUDA s’exécute sur ROCm en définissant trois variables d’environnement :

os.environ["ROCRVISIBLEDEVICES"] = "0"
os.environ["HIPVISIBLEDEVICES"] = "0"
os.environ["HSAOVERRIDEGFX_VERSION"] = "9.4.2"

C’est tout. Aucune modification de code. Aucun kernel personnalisé. Aucun adaptateur CUDA.

LE JEU DE DONNÉES : MEDMCQA

MedMCQA est un vaste ensemble de questions à choix multiple tirées des examens d’entrée en médecine en Inde (AIIMS, style USMLE). Chaque exemple contient :

### Question:
{question}

### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### Answer:
{answerletter}) {answertext}

### Explanation:
{explanation}

Pour ce projet, 2 000 échantillons d’entraînement ont été utilisés, un sous-ensemble volontairement réduit pour montrer qu’un affinage significatif est réalisable rapidement. L’entraînement a pris environ 5 minutes sur la MI300X.

LE MODÈLE : QWEN3-1.7B

Le modèle de base est Qwen/Qwen3-1.7B, le dernier petit modèle de langage d’Alibaba. Avec 1,7 milliard de paramètres, il est assez compact pour être affiné à moindre coût, mais assez performant pour produire un raisonnement clinique cohérent. Il prend en charge trustremotecode=True et se charge sans problème avec HuggingFace Transformers.

trustremotecode=True

LE FORMAT DU PROMPT

La cohérence du format de prompt est cruciale pour l’affinage instruction. Chaque exemple d’entraînement et chaque appel d’inférence utilise le même modèle :

### Question:
{question}

### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### Answer:
{answerletter}) {answertext}

### Explanation:
{explanation}

Pendant l’entraînement, le modèle voit la séquence complète, y compris la réponse et l’explication. Pendant l’inférence, on fournit tout jusqu’à ### Answer:\n et on laisse le modèle compléter à partir de là.

CONFIGURATION LORA

Seulement ~2,2 millions des 1,5 milliard de paramètres du modèle sont entraînés. Cela maintient une faible utilisation mémoire et un entraînement rapide.

from peft import LoraConfig, getpeftmodel, TaskType

lora_config = LoraConfig(
    tasktype=TaskType.CAUSALLM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    targetmodules=["qproj", "v_proj"],
    bias="none",
)

model = getpeftmodel(model, lora_config)
model.printtrainableparameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443

PARAMÈTRES D’ENTRAÎNEMENT

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    numtrainepochs=2,
    perdevicetrainbatchsize=4,
    gradientaccumulationsteps=4,     # effective batch size = 16
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    loadbestmodelatend=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_ratio=0.05,
    lrschedulertype="cosine",
    report_to="none",
)

Quelques points à noter :

fp16=True, bf16=False

La MI300X fonctionne très bien en fp16 natif. Pas besoin de bf16.

gradient_checkpointing=True

Indispensable : sans cela, le modèle de 1,7 milliard de paramètres ne rentre pas en mémoire, même avec LoRA.

gradientaccumulationsteps=4

Cela donne une taille de lot effective de 16, ce qui est suffisant pour cet ensemble de données. Les charges de travail médicales bénéficient d’un petit taux d’apprentissage stable plutôt que de gros lots.

BOUCLE D’ENTRAÎNEMENT COMPLÈTE

from transformers import DataCollatorForSeq2Seq, Trainer

collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    padtomultiple_of=8,
)

trainer = Trainer(
    model=model,
    args=args,
    traindataset=trainds,
    evaldataset=valds,
    data_collator=collator,
)

trainer.train()

# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")

Après l’entraînement, ./outputs contient les poids de l’adaptateur LoRA, quelques mégaoctets au lieu d’un point de contrôle complet de plusieurs gigaoctets.

INFÉRENCE

Au moment de l’inférence, on charge le modèle de base, on attache l’adaptateur LoRA, et on peut éventuellement fusionner les poids :

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.frompretrained("./outputs", trustremote_code=True)
tokenizer.padtoken = tokenizer.eostoken

basemodel = AutoModelForCausalLM.frompretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,
    device_map="auto",
    trustremotecode=True,
)

model = PeftModel.frompretrained(basemodel, "./outputs")
model.eval()

La génération utilise un décodage glouton (do_sample=False) avec une pénalité de répétition pour éviter que le modèle ne boucle :

def generate(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            maxnewtokens=200,
            do_sample=False,
            temperature=1.0,
            repetition_penalty=1.1,
            eostokenid=tokenizer.eostokenid,
            padtokenid=tokenizer.eostokenid,
        )

    newtokens = output[0][inputs["inputids"].shape[-1]:]
    return tokenizer.decode(newtokens, skipspecial_tokens=True)

EXEMPLE DE SORTIE

Question : Lequel des traitements suivants est le premier choix pour une urgence hypertensive ?

A) Oral amlodipine
B) IV labetalol or IV nitroprusside
C) Sublingual nifedipine
D) IM hydralazine

Sortie du modèle :

B) IV labetalol or IV nitroprusside

Explanation:
Intravenous labetalol (beta-blocker) or nitroprusside rapidly reduces blood
pressure in emergency settings. Oral agents act too slowly for hypertensive
emergencies requiring immediate BP control to prevent end-organ damage.

Le modèle ne se contente pas de donner une lettre, il explique pourquoi, ce qui le rend cliniquement utile.

CHARGEMENT DEPUIS LE HUB HUGGINGFACE

L’adaptateur affiné est disponible publiquement. On peut le charger directement sans cloner le dépôt :

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-1.7B", trustremotecode=True
)

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trustremotecode=True,
)

model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.mergeandunload()
model.eval()

DÉFIS RENCONTRÉS ET SOLUTIONS

Aucun projet AMD ROCm n’est complet sans une section sur les obstacles. Voici ce qui a été rencontré :

1. bitsandbytes ne compile pas sur ROCm. La solution est simple : ne pas l’utiliser. ROCm ne prend pas en charge bitsandbytes pour le moment. Heureusement, avec 192 Go de HBM3, la quantification n’est pas nécessaire.

2. flash-attention n’est pas disponible sur ROCm. Aucune solution simple n’existe pour l’instant. Le modèle s’est donc passé de flash attention, ce qui n’a posé aucun problème à cette échelle.

3. Le tokenizer padtoken doit être défini sur eostoken. Sans cela, le collateur de données échoue. C’est une pratique courante pour les modèles de langage décodeurs uniquement.

padtoken = eostoken

4. La version de transformers doit être ≥4.40.0. Les versions antérieures ne reconnaissent pas correctement l’architecture Qwen3 lorsqu’on charge avec trustremotecode.

transformers>=4.40.0

À noter : sur le matériel NVIDIA, une quantification 4 bits est souvent nécessaire pour faire tenir un modèle en mémoire. Sur la MI300X avec 192 Go de HBM3, c’est inutile. C’est un véritable avantage matériel : un entraînement plus propre, sans artefacts de quantification.

RÉSULTATS

Seulement 2,2 millions de paramètres entraînés, sur un total de 1,5 milliard, en 5 minutes sur MI300X.
MesureValeur
Paramètres entraînables~2,2M (0,15% du total)
Temps d’entraînement sur MI300X~5 minutes
Taille du jeu de données utilisée2 000 échantillons
Précision de base MedMCQA~45%
FrameworkPyTorch + ROCm 6.1

ESSAYEZ PAR VOUS-MÊME

Pas de GPU ? Pas de problème. La démo Gradio en direct tourne sur HuggingFace Spaces (inférence CPU).

Vous avez du matériel AMD ? Clonez le dépôt et lancez-le en natif :

git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py   # ~5 minutes
python infer.py   # exécute des questions types
python app.py     # lance l’interface Gradio

PROCHAINES ÉTAPES

Ce projet prouve que le pipeline fonctionne. Les prochaines étapes concernent la mise à l’échelle et le renforcement :

Entraînement sur l’intégralité du jeu de données MedMCQA pour obtenir des performances de niveau production. Exploration de l’optimisation flash attention pour ROCm quand elle sera disponible. Amélioration de la qualité des explications via un raffinement supplémentaire du prompt. Ajout d’une interface Gradio plus complète pour les tests utilisateurs. Intégration de métriques d’évaluation clinique robustes.

CONCLUSION

MedQA montre que construire une IA médicale performante et explicable sur du matériel AMD open source est non seulement possible, mais simple. La compatibilité de l’écosystème HuggingFace avec ROCm est vraiment bonne. La marge de mémoire de la MI300X élimine toute une catégorie de problèmes d’ingénierie. Et LoRA fait de l’affinage d’un modèle de 1,7 milliard de paramètres une tâche de 5 minutes.

Si vous construisez sur AMD ROCm et que vous rencontrez des obstacles, les solutions ci-dessus devraient vous faire gagner des heures. Et si vous construisez de l’IA médicale, l’accent mis sur l’explication plutôt que sur la simple précision mérite d’être pris au sérieux.

Construit pour le Hackathon AMD Developer sur lablab.ai · Propulsé par AMD ROCm + l’écosystème HuggingFace

Sources :
  • Hugging Face Blog

L'indépendance de CLODCO est votre garantie.

Pour que l'actualité de l'IA reste sans filtre et sans concession, votre soutien est indispensable. Votre contribution est le seul moteur de notre liberté éditoriale.

Soutenir CLODCO