arXiv reaDer
報酬の再計量、再選択、再トレーニングによるプロトタイプ部品ネットワークの改善
Improving Prototypical Part Networks with Reward Reweighing, Reselection, and Retraining
近年では、モデルの出力をデータの特定の特徴に明確に帰属させる、深く解釈可能な画像分類方法の開発に取り組んでいます。これらの方法の 1 つは、プロトタイプ部分ネットワーク (ProtoPNet) です。これは、入力の意味のある部分に基づいて画像を分類しようとします。この方法では解釈可能な分類が得られますが、多くの場合、この方法は画像の偽の部分または矛盾した部分から分類することを学習します。これを改善することを期待して、私たちはヒューマン フィードバックによる強化学習 (RLHF) の最近の開発からインスピレーションを得て、これらのプロトタイプを微調整しています。 CUB-200-2011 データセットの 1 ~ 5 のスケールでプロトタイプの品質に関する人間によるアノテーションを収集することで、偽ではないプロトタイプを識別することを学習する報酬モデルを構築します。完全な RL アップデートの代わりに、再重み付け、再選択、および再トレーニングされたプロトタイプ パーツ ネットワーク (R3-ProtoPNet) を提案します。これにより、ProtoPNet トレーニング ループに 3 つのステップが追加されます。最初の 2 つのステップは報酬ベースの再重み付けと再選択で、プロトタイプを人間のフィードバックに合わせます。最後のステップは、更新されたプロトタイプを使用してモデルの特徴を再調整するための再トレーニングです。 R3-ProtoPNet はプロトタイプの全体的な一貫性と有意義性を向上させますが、単独で使用するとテストの予測精度が低下することがわかりました。複数の R3-ProtoPNet がアンサンブルに組み込まれると、解釈可能性を維持しながらテスト予測パフォーマンスが向上することがわかります。
In recent years, work has gone into developing deep interpretable methods for image classification that clearly attributes a model's output to specific features of the data. One such of these methods is the prototypical part network (ProtoPNet), which attempts to classify images based on meaningful parts of the input. While this method results in interpretable classifications, this method often learns to classify from spurious or inconsistent parts of the image. Hoping to remedy this, we take inspiration from the recent developments in Reinforcement Learning with Human Feedback (RLHF) to fine-tune these prototypes. By collecting human annotations of prototypes quality via a 1-5 scale on the CUB-200-2011 dataset, we construct a reward model that learns to identify non-spurious prototypes. In place of a full RL update, we propose the reweighted, reselected, and retrained prototypical part network (R3-ProtoPNet), which adds an additional three steps to the ProtoPNet training loop. The first two steps are reward-based reweighting and reselection, which align prototypes with human feedback. The final step is retraining to realign the model's features with the updated prototypes. We find that R3-ProtoPNet improves the overall consistency and meaningfulness of the prototypes, but lower the test predictive accuracy when used independently. When multiple R3-ProtoPNets are incorporated into an ensemble, we find an increase in test predictive performance while maintaining interpretability.
updated: Sat Jul 08 2023 03:42:54 GMT+0000 (UTC)
published: Sat Jul 08 2023 03:42:54 GMT+0000 (UTC)
参考文献 (このサイトで利用可能なもの) / References (only if available on this site)
被参照文献 (このサイトで利用可能なものを新しい順に) / Citations (only if available on this site, in order of most recent)
Amazon.co.jpアソシエイト