[์ ํ์ฐ๊ตฌํ ๊น์ฑํ]
์์ฐ์ด์ฒ๋ฆฌ ๋ถ์ผ์์ pre-trained language model (PLM) ์ ๋ต์ด ํ๋ฅญํ ์ฑ๊ณต์ ๊ฑฐ๋์, ๋ ๋ง์ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํด ๋ ํฐ PLM์ ๊ฐ๋ฐํ๋ ๊ฒ์ด ํ๋์ ํธ๋๋๋ก ์๋ฆฌ์ก์์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ผ๋ง ์ , NVIDIA์์๋ GPT-3์ ๋ฌด๋ ค 4๋ฐฐ ๊ฐ๊น์ด ๋๋ 530B๊ฐ์ ํ๋ผ๋ฏธํฐ์ง๋ฆฌ ๋ชจ๋ธ์ ๊ณต๊ฐํ์ต๋๋ค.
์ด ๋ชจ๋ธ์ ๊ธฐ์กด์ Megatron-LM ๋ชจ๋ธ๊ณผ Turing-NLG ๋ชจ๋ธ์ ๊ฒฐํฉํ์ฌ, “Megatron-Turing NLG” (MT-NLG) ๋ผ๋ ์ด๋ฆ์ผ๋ก ๋ช
๋ช
๋์ต๋๋ค.
๋ชจ๋ธ์ ํ์ต์ DGX A100 80G ์๋ฒ 560๋๋ฅผ ํ๋์ ํด๋ฌ์คํฐ๋ก ๋ฌถ์ด์ ํ์ตํ๋ค๊ณ ํฉ๋๋ค. ์ ๋ง NVIDIA๊ฐ ์๋๊ณ ์๋ ์คํ๋ ๋ถ๊ฐ๋ฅํ ์ ๋์ ๋ชจ๋ธ์ด๋ค์!
์ด 105๊ฐ์ transformer layer๋ก ๊ตฌ์ฑ๋์ด ์๊ณ , zero-, one- ๊ทธ๋ฆฌ๊ณ few-shot learning task์์ ์ต๊ณ ์ ์ฑ๋ฅ์ ๋ณด์๋ค๊ณ ํฉ๋๋ค.
์ด๋ ๊ฒ ํฐ ๋ชจ๋ธ์ ํ์ตํ๋๋ฐ๋ ๋จ์ํ ๋ง์ ๋, ๋ง์ ๋ฐ์ดํฐ, ๋ง์ GPU๋ง์ ํ์๋ก ํ์ง ์์ต๋๋ค.
์๋์ ๋ฌธ์ ๋ค ๋๋ฌธ์ธ๋ฐ์, ์ฐ์ (1) GPU์ ๋ฉ๋ชจ๋ฆฌ๋ ํ์ ๋์ด ์๊ณ , ์์ฒญ ํฐ hyper parameter๋ฅผ ๋ชจ๋ ํ์ตํ๋๋ฐ๋ ์ ๋ ์ถฉ๋ถํ์ง ์์ต๋๋ค. (2) ํ์ต ์๊ณ ๋ฆฌ์ฆ ์ต์ ํ, ๋ฐ์ดํฐ ์ฒ๋ฆฌ ๋ฐฉ๋ฒ, ์ํํธ์จ์ด-ํ๋์จ์ด ์ต์ ํ๋ฅผ ๋ชจ๋ ๊ณ ๋ คํ์ง ์์ผ๋ฉด, ๋นํ์ค์ ์ผ๋ก ํ์ต์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆด ์ ์์ต๋๋ค.
์ด๋ฒ์ ๊ณต๊ฐ๋ MT-NLG์ ๊ฒฝ์ฐ, Microsoft์ NVIDIA๊ฐ ํ์
ํ์ฌ ์ ๋ก์๋ ๋ชจ๋ธ ํ์ต ํจ์จ์ ๋ฌ์ฑํด์ ๋ง๋ค์ด๋ผ ์ ์์๋ค๊ณ ํฉ๋๋ค ๐
์ฆ, ํ๋์จ์ด์ ์ํํธ์จ์ด์ ์์คํ
๊ตฌ์กฐ๊น์ง ๋ชจ๋ ํ์
ํ๊ณ ์์ด์ผ ํจ์จ์ ์ธ ํ์ต์ด ๊ฐ๋ฅํ๋ค๋ ๊ฑฐ๊ฒ ์ฃ ?
๋ ์์ธํ ์ด์ผ๊ธฐ๋ (๋งํฌ) ์์ ํ์ธํด๋ณด์ค ์ ์์ต๋๋ค.
ํ์ต ์ต์ ํ ๊ด๋ จํด์ ์ถ๊ฐ๋ก ์ฌ๋ฐ๊ฒ ์ฝ์ ๋
ผ๋ฌธ์ด ์์ด์ ๊ณต์ ๋๋ฆฝ๋๋ค.
์ ๋ชฉ(How to train BERT with an academic budget)์์ ์ ์ ์๋ฏ์ด, BERT๊ฐ์ large scaled model๋ค์ ์ด๋ป๊ฒ ์ต์ ํํ์ฌ ์ ๋ ดํ๊ฒ ํ์ตํ ์ ์๋์ง์ ๊ดํ ๋
ผ๋ฌธ์
๋๋ค.
๋ณธ ๋
ผ๋ฌธ์์๋ ๋จผ์ ํ์ต ํ๊ฒฝ๋ถํฐ ์ ํํ์ฌ ์ค์ ํ๋๋ฐ์, (1) 24์๊ฐ ๋ด์ ํ์ต๋ ๊ฒ, (2) 8๊ฐ์ NVIDIA Titan-V GPU (๊ฐ๊ฐ 12GB) ๋ก ํ์ต์ ์๋ํ๋ค๊ณ ํฉ๋๋ค.
์ฐธ๊ณ ๋ก, 8๊ฐ์ Titan-V GPU๋ก 24์๊ฐ ํ์ตํ๋ ๊ฒ์ 4๊ฐ์ RTX 3090 GPU๋ก ํ๋ฃจ, 40GB์ง๋ฆฌ 1๊ฐ์ A100 GPU๋ก 2.4์ผ ํ์ตํ ๊ฒ๊ณผ ์ ์ฌํ๋ค๊ณ ํ๋ค์ ๐
ํ์ต ๋ฐ์ดํฐ๋ ์์ด wikipedia, Toronto BookCorpus๋ก๋ถํฐ ํ๋ํ 16GB์ ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํ๋ค๊ณ ํฉ๋๋ค.
ํ์ต์ BERT-style์ transformer encoder์ MLM objective๋ก ์งํํ์์ต๋๋ค.
๋ํ, sentence classification task๋ฅผ ๋ชฉ์ ์ผ๋ก ํ์ตํ๋ PLM์ด๊ธฐ ๋๋ฌธ์, 128๊ฐ๋ก token ๊ธธ์ด๋ฅผ ์ ํํ์๋๋ฐ, ์ด๋ BERT์ ์ ๋
ผ๋ฌธ์์๋ ์ ์ฉ๋ ๋ฐฉ๋ฒ์ด๋ผ๊ณ ํฉ๋๋ค. (ํ์ต์ ์ด๊ธฐ 90%๋ 127 ํ ํฐ์ผ๋ก, ๋๋จธ์ง 10%๋ 512 ํ ํฐ์ผ๋ก ํ์ต)
ํจ๊ณผ๊ฐ ๋ฏธ๋นํ ๊ฒ์ผ๋ก ์ ์๋ ค์ง ๊ฒ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก, next sentence prediction (NSP) ๋ ํ์ต์์ ์ ๊ฑฐํ๊ณ single sentence๋ง ํ์ตํ์ผ๋ฉฐ, ํ์ต ์๊ฐ์ ํฌํจ๋๋ validation loss๋ฅผ ๊ณ์ฐํ๋ ์๊ฐ๋ง์ ์ค์ด๊ธฐ ์ํด, 30๋ถ๋ง๋ค 0.5%์ validation set๋ง์ ๊ณ์ฐํ๋ค๊ณ ํฉ๋๋ค.
๋ชจ๋ธ์ ์ฌ์ด์ฆ๋ BERT-large์ ๋์ผํ๊ฒ ์ธํ
ํ์ผ๋ฉฐ, DeepSpeed๋ฅผ ํตํด data parallelization, mixed-precision ์ ์ ์ฉํ์ต๋๋ค.
MLM prediction head๋ฅผ sparse token prediction์ผ๋ก ๋ฐ๊พธ์์ผ๋ฉฐ, APEX LayerNorm์ ์ ์ฉํจ์ผ๋ก์จ ํ์ต์ ์ต์ ํํ์ต๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก, ์ด๋ ๊ฒ ์ต์ ํ BERT model์ ๊ฒฝ์ฐ, ๋์ผํ batch size (bsz)๋ก ํ์ตํ ๋๋ ๊ธฐ์กด BERT ๋๋น 2๋ฐฐ ์ ๋ ๋น ๋ฅธ ์๋๋ก ํ์ตํ๊ณ , batch size๋ฅผ ์ต๋ํ์ผ๋ก ๋๋ฆฌ์ 2.41์ผ ๋ง์ ํ์ต์ด ๊ฐ๋ฅํ๋ค๊ณ ํฉ๋๋ค ๐
๋จ 24์๊ฐ์ผ๋ก ํ์ต์ ์ ํํ ๊ฒฝ์ฐ, ๊ธฐ์กด PLM๊ณผ ์ ์ฌํ ์ฑ๋ฅ์ ๋ณด์๋ค๊ณ ํ๋ค์! ๐
ํ์ต ์ต์ ํ ๊ด๋ จ๋ ๊ธฐ์ ์ ์์ผ๋ก๋ ๊ณ์ ๋ฐ์ ์ค์
๋๋ค!
๋์ค์๋ ๊ฐ์ธ PC๋ก GPT-3๋ฅผ ํ์ตํ ์ ์๋ ๊ธฐ์ ๋ ๊ฐ๋ฅํ ์ง ๋ชจ๋ฅด๊ฒ ๋ค์ ๐