Memprediksi Celah Generalisasi dalam Deep Neural Networks

Deep neural networks (DNN) merupakan landasan kemajuan terbaru dalam machine learning, dan bertanggung jawab atas terobosan terbaru dalam berbagai tugas seperti pengenalan gambar, segmentasi gambar, terjemahan mesin, dll. Namun, meskipun terdapat di mana-mana, para peneliti masih berusaha untuk sepenuhnya memahami prinsip-prinsip dasar yang mengaturnya. Secara khusus, teori-teori klasik (misalnya, VC-dimensi dan kompleksitas Rademacher) menunjukkan bahwa fungsi over-parameterisasi harus digeneralisasikan dengan buruk untuk data yang tidak terlihat, namun karya terbaru telah menemukan bahwa fungsi over-parameterized secara besar-besaran (urutan besarnya lebih banyak parameter daripada jumlah data poin) generalisasi dengan baik. Untuk meningkatkan model, diperlukan pemahaman yang lebih baik tentang generalisasi, yang dapat mengarah pada pendekatan yang lebih teoretis, maka ia lebih berprinsip pada desain DNN.

Konsep penting untuk memahami generalisasi adalah kesenjangan generalisasi, yaitu, perbedaan antara kinerja model pada data pelatihan dan kinerjanya pada data tak terlihat yang diambil dari distribusi yang sama. Langkah-langkah signifikan sudah dijalankan untuk mendapatkan batas-batas generalisasi DNN yang lebih baik — batas atas kesenjangan generalisasi — tetapi mereka masih cenderung terlalu melebih-lebihkan kesenjangan generalisasi yang sebenarnya, menjadikannya tidak informatif mengapa beberapa model menggeneralisasi dengan baik. Di sisi lain, gagasan margin - jarak antara titik data dan batas keputusan - telah dipelajari secara luas dalam konteks model dangkal seperti mesin vektor-dukungan, dan ditemukan terkait erat dengan seberapa baik model ini menggeneralisasi untuk data yang tidak terlihat. Karena itu, penggunaan margin untuk mempelajari kinerja generalisasi telah diperluas ke DNN, menghasilkan batas atas teoretis yang sangat halus pada kesenjangan generalisasi, tetapi belum secara signifikan meningkatkan kemampuan untuk memprediksi seberapa baik suatu model digeneralisasikan.

Contoh batas keputusan mesin dukungan-vektor. Hyperplane yang didefinisikan oleh w ∙ x-b = 0 adalah "batas keputusan" dari pengklasifikasi linier ini, yaitu, setiap titik x yang terletak di hyperplane kemungkinan sama-sama berada di salah satu kelas di bawah classifier ini.

Dalam makalah ICLR 2019 kami, "Memprediksi Kesenjangan Generalisasi di Deep Networks dengan Distribusi Margin", kami mengusulkan penggunaan distribusi margin yang dinormalisasi di seluruh lapisan jaringan sebagai prediktor kesenjangan generalisasi. Kami secara empiris mempelajari hubungan antara distribusi margin dan generalisasi dan menunjukkan bahwa, setelah normalisasi jarak, beberapa statistik dasar dari distribusi margin dapat secara akurat memprediksi kesenjangan generalisasi. Kami juga menyediakan semua model yang digunakan sebagai dataset untuk mempelajari generalisasi melalui repositori Github.

Setiap plot sesuai dengan jaringan saraf convolutional yang dilatih pada CIFAR-10 dengan akurasi klasifikasi yang berbeda. Densitas probabilitas (sumbu-y) dari distribusi margin yang dinormalisasi (sumbu-x) pada 4 lapisan jaringan ditunjukkan untuk tiga model yang berbeda dengan generalisasi yang semakin baik (kiri ke kanan). Distribusi margin yang dinormalisasi sangat berkorelasi dengan akurasi tes, yang menunjukkan mereka dapat digunakan sebagai proksi untuk memprediksi kesenjangan generalisasi jaringan. Silakan lihat makalah kami untuk detail lebih lanjut tentang jaringan ini.


Distribusi Margin sebagai Prediktor Generalisasi

Secara intuitif, jika statistik distribusi margin benar-benar dapat memprediksi kinerja generalisasi, skema prediksi sederhana harus bisa membangun hubungan. Karena itu, kami memilih regresi linier sebagai prediktornya. Kami menemukan bahwa hubungan antara kesenjangan generalisasi dan statistik log-transformed dari distribusi margin hampir linear sempurna (lihat gambar di bawah). Bahkan, skema yang diusulkan menghasilkan prediksi yang lebih baik relatif terhadap ukuran generalisasi lain yang ada. Hal ini menunjukkan bahwa distribusi margin mungkin berisi informasi penting tentang seberapa dalam generalisasi model.

Kesenjangan generalisasi terprediksi (sumbu x) vs kesenjangan generalisasi sebenarnya (sumbu y) pada CIFAR-100 + ResNet-32. Titik-titik terletak dekat dengan garis diagonal, yang menunjukkan bahwa nilai prediksi model linear log cocok dengan kesenjangan generalisasi sebenarnya dengan sangat baik.

Dataset Generalisasi Model Mendalam

Selain makalah kami, kami memperkenalkan dataset Deep Model Generalisation (DEMOGEN), yang terdiri dari 756 model mendalam terlatih, bersama dengan pelatihan dan kinerja pengujian mereka pada dataset CIFAR-10 dan CIFAR-100. Model-model tersebut adalah varian dari CNN (dengan arsitektur yang serupa dengan Network-in-Network) dan ResNet-32 dengan berbagai teknik regularisasi populer dan pengaturan hyperparameter, yang menginduksi spektrum luas dari perilaku generalisasi. Sebagai contoh, model CNN yang dilatih CIFAR-10 memiliki akurasi pengujian mulai dari 60% hingga 90,5% dengan kesenjangan generalisasi mulai dari 1% hingga 35%. Untuk detail dataset, silakan lihat makalah kami atau repositori Github. Sebagai bagian dari rilis dataset, kami juga menyertakan utilitas untuk dengan mudah memuat model dan mereproduksi hasil yang disajikan dalam makalah kami.

Harapan kami, penelitian ini dan dataset DEMOGEN akan memberi masyarakat suatu alat yang dapat diakses untuk mempelajari generalisasi dalam deep learning tanpa harus melatih ulang sejumlah besar model. Kami juga berharap bahwa temuan kami akan memotivasi penelitian lebih lanjut dalam prediksi kesenjangan generalisasi dan distribusi margin di lapisan tersembunyi.