Nén 12 model thành 1 — liệu có mất gì?

Nén 12 model thành 1 — liệu có mất gì?

Knowledge Distillation giúp dồn sức mạnh cả đội ensemble vào một model duy nhất đủ nhẹ để deploy. Nhưng "nén" không có nghĩa là "miễn phí".

Bạn đang giữ 12 model trên production à?

Giả sử team bạn vừa hoàn thành một hệ thống phân loại — chạy 12 model ensemble, accuracy đẹp như mơ. Rồi đến lúc deploy, DevOps nhìn bill cloud rồi hỏi: "Ủa 12 model chạy song song thiệt hả? Budget tháng này ai chịu?"

Đây không phải chuyện hiếm. Ensemble — tức ghép nhiều model lại để dự đoán — là vũ khí kinh điển trong ML. Nó giảm variance, bắt được nhiều pattern hơn, nhìn metric đẹp lắm. Nhưng mang lên production? Latency nhân đôi, chi phí nhân ba, vận hành thì... hên xui.

Vậy câu hỏi thật sự là: có cách nào giữ được "trí tuệ" của cả ensemble mà chỉ deploy MỘT model nhỏ gọn không?

Khoan — chuyện phức tạp hơn "copy-paste" nhiều

Knowledge Distillation (KD) nói thẳng ra thì: bạn cho một model nhỏ (student) "học lỏm" từ một model lớn hoặc ensemble (teacher), thay vì chỉ học từ dữ liệu gốc.

Điểm then chốt: student không học đáp án đúng/sai (hard labels), mà học phân phối xác suất mềm (soft labels) từ teacher. Hình dung thế này — bạn đang đi thi trắc nghiệm. Hard label chỉ nói "đáp án là C". Soft label nói "C chiếm 70%, B chiếm 20%, A chiếm 8%, D chiếm 2%". Cái 20% kia cho B chứa thông tin quý giá — nó cho biết rằng "B cũng gần đúng đấy, đừng bỏ qua".

Teacher ensemble tạo soft labels bằng temperature scaling — tham số T giúp "làm mềm" phân phối xác suất. T cao thì xác suất phẳng hơn, student học được nhiều nuance hơn. T thấp thì gần giống hard label, student học ít "kiến thức ngầm" hơn.

Theo pipeline từ bài nghiên cứu gốc — 12 model teacher distill vào một student duy nhất — student thu hồi được 53.8% khoảng cách accuracy so với ensemble, trong khi model nhỏ hơn 160 lần. Tức là: bạn mất gần nửa "edge", nhưng đổi lại model nhẹ gấp trăm lần. Đây là trade-off có ý thức, không phải magic.

Hai kịch bản thật từ team Việt Nam

Kịch bản 1 — Startup fintech phân loại giao dịch gian lận

Giả sử team bạn 5 người, đang chạy 8 model XGBoost + Random Forest ensemble để detect fraud. Accuracy trên test set rất ổn, nhưng mỗi request mất 200ms vì phải chạy qua cả 8 model rồi aggregate. Với Knowledge Distillation, bạn dùng ensemble làm teacher, generate soft labels cho toàn bộ training set, rồi train một neural network nhỏ (2-3 layers) làm student. Deploy student duy nhất — latency giảm xuống dưới 20ms.

Trade-off? Student sẽ không bằng ensemble trên mọi edge case, nhưng với phần lớn traffic thông thường, kết quả gần như tương đương. Và sếp không còn nhíu mày mỗi lần nhìn hóa đơn.

Kịch bản 2 — Team NLP xây chatbot nội bộ

Team bạn fine-tune một model lớn (giả sử 7B parameters) để trả lời câu hỏi nội bộ công ty. Chạy ngon trên GPU A100, nhưng muốn deploy lên máy nhân viên thì... hết vía. Distill xuống model 1-2B, student học từ output distributions của teacher. Model nhỏ chạy được trên laptop có GPU tầm trung, trả lời vẫn đủ tốt cho phần lớn câu hỏi thường gặp. Những câu khó, route ngược lên model lớn — hybrid serving.

Thử ngay chiều nay — 4 bước distill ensemble

Bạn có Python và một buổi chiều rảnh? Đủ rồi.

Bước 1: Chuẩn bị ensemble teacher. Nếu chưa có, dùng scikit-learn train nhanh 5-10 model (Random Forest, GBM, SVM) trên dataset của bạn.

Bước 2: Generate soft labels. Với mỗi sample trong training set, lấy output xác suất trung bình từ tất cả model. Đây chính là "kiến thức" bạn muốn chuyển giao.

import numpy as np
soft_labels = np.mean(
    [m.predict_proba(X_train) for m in ensemble_models], axis=0
)

Bước 3: Train student. Dùng PyTorch, tạo một network nhỏ. Loss function kết hợp hai thành phần — KL divergence giữa output student và soft labels (có temperature scaling), cộng cross-entropy loss với hard labels:

loss = alpha * kl_div_loss(student_soft, teacher_soft, T=3.0) \
     + (1 - alpha) * ce_loss(student_out, hard_labels)

Tham số alpha thường bắt đầu ở 0.7 — nghiêng nhiều về soft labels. T (temperature) thường thử trong khoảng 2-5.

Bước 4: Evaluate và so sánh. Chạy student trên test set, đo accuracy, latency, model size. So sánh ba chiều: ensemble teacher vs. student distilled vs. model nhỏ train thường (không distill). Sự khác biệt giữa hai cái sau chính là giá trị distillation mang lại.

Mấy cái bẫy mà team nào cũng dễ dính

Mình từng thấy một team distill xong hí hửng deploy, rồi 2 tuần sau mới phát hiện: student hoạt động cực tệ trên một phân khúc khách hàng nhỏ nhưng quan trọng. Lý do? Ensemble teacher có 3 model chuyên xử lý phân khúc đó, nhưng khi nén lại, student "quên" luôn nhóm thiểu số.

Mấy bẫy phổ biến khác:

Open-source tools và hướng đi tiếp

Nếu bạn làm deep learning, Hugging Face đã hỗ trợ distillation khá smooth qua thư viện transformers. DistilBERT chính là sản phẩm nổi tiếng nhất của KD — model nhỏ hơn, nhanh hơn đáng kể, mà giữ được phần lớn performance so với BERT gốc. Với ML truyền thống (tabular data), scikit-learn kết hợp PyTorch là đủ dùng.

Như mình đã chia sẻ trong các bài trước về quantization và model compression, distillation và quantization là hai hướng bổ trợ nhau. Bạn hoàn toàn có thể distill trước, rồi quantize sau để có model siêu nhẹ — nén hai tầng.

Một dòng mang về

Knowledge Distillation không phải phép màu — bạn LUÔN đánh đổi một phần accuracy. Nhưng nếu bạn đang nuôi một ensemble khổng lồ trên production mà hóa đơn cloud mỗi tháng khiến cả team ngại mở email billing, thì đây là cách đổi "một ít chính xác" lấy "rất nhiều tiền và tốc độ". Cái khó không phải kỹ thuật distill, mà là biết khi nào trade-off đó chấp nhận được — và kiểm chứng nó đủ kỹ trước khi bấm deploy.

Spoiler: không có silver bullet — nhưng có silver lining.

---
Bụi Wire — nghiện đọc release notes lúc 2 giờ sáng

Nguồn tham khảo