[SSUDA] Practicing Transformers with the DACON Code Similarity Competition
• Seong Yeon Kim • 26 min read
SSUDA transformer Deep Learning natural language DACON
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Mount Google Drive to load the data files.
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings("ignore")
path = '/content/drive/MyDrive/coding/'
train = pd.read_csv(path + 'sample_train.csv')
test = pd.read_csv(path + 'test.csv')
sample_submission = pd.read_csv(path + 'sample_submission.csv')
train.head()
| | code1 | code2 | similar |
|---|---|---|---|
| 0 | flag = "go"\ncnt = 0\nwhile flag == "go":\n ... | # Python 3+\n#--------------------------------... | 1 |
| 1 | b, c = map(int, input().split())\n\nprint(b * c) | import numpy as np\n\nn = int(input())\na = np... | 0 |
| 2 | import numpy as np\nimport sys\nread = sys.std... | N, M = map(int, input().split())\nif M%2 != 0:... | 0 |
| 3 | b, c = map(int, input().split())\n\nprint(b * c) | n,m=map(int,input().split())\nh=list(map(int,i... | 0 |
| 4 | s=input()\nt=input()\nans=0\nfor i in range(le... | import math\na,b,h,m=map(int,input().split())\... | 0 |
Import the required packages and load the data.
!pip install transformers knockknock
from knockknock import discord_sender
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
Collecting knockknock
Downloading knockknock-0.1.8.1-py3-none-any.whl (28 kB)
(dependency resolution and download logs omitted)
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.
Successfully installed APScheduler-3.6.3 PyJWT-2.4.0 SecretStorage-3.3.2 cachetools-4.2.2 cryptography-37.0.2 cssselect-1.1.0 cssutils-2.4.0 huggingface-hub-0.7.0 jeepney-0.8.0 keyring-23.5.1 knockknock-0.1.8.1 matrix-client-0.4.0 premailer-3.10.0 python-telegram-bot-13.12 pyyaml-6.0 tokenizers-0.12.1 tornado-6.1 transformers-4.19.2 twilio-7.9.1 yagmail-0.15.277
This installs the package that sends notifications to a Discord server; usage is explained below.
print(train.shape)
print(test.shape)
(17970, 3)
(179700, 3)
tem = CountVectorizer()
tem.fit(train['code1'])
tem.vocabulary_
{'flag': 4281, 'go': 4603, 'cnt': 2991, 'while': 10034, 'int': 5201, 'input': 5137, 'if': 4958, 'stop': 9031, 'else': 3928, 'print': 7773, 'case': 2833, 'str': 9035, 'map': 6152, 'split': 8907, 'import': 5023, 'numpy': 7117, 'as': 2180, 'np': 7004, 'sys': 9222, 'read': 8108, 'stdin': 9007, 'buffer': 2639, 'readline': 8129, ...}
(output truncated; the full vocabulary maps 12,552 tokens to integer indices)
CountVectorizer maps every word that appears in the input to an integer index.
tem.transform([train['code1'][0]]).toarray()
array([[0, 0, 0, ..., 0, 0, 0]])
tem.transform([train['code1'][0]]).toarray().shape
(1, 12552)
transform converts the code into a vector over the 12,552-word vocabulary: an entry is nonzero only when the corresponding vocabulary word appears in code1.
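To make the vectorization concrete, here is a minimal toy sketch (made-up strings, not the competition data) of how CountVectorizer assigns indices (alphabetically) and counts occurrences:
toy = CountVectorizer()
toy.fit(["print ( input )", "import numpy as np"])
print(toy.vocabulary_)  # {'as': 0, 'import': 1, 'input': 2, 'np': 3, 'numpy': 4, 'print': 5}
print(toy.transform(["print input input"]).toarray())  # [[0 0 2 0 0 1]]: counts per vocabulary word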
cosine_similarity(tem.transform([train['code1'][0]]), tem.transform([train['code2'][0]]))
array([[0.32871913]])
train['similar'][0]
1
The cosine_similarity function measures similarity between the one-hot style count vectors produced above.
A score of 0.32 suggests some similarity, and the two codes are indeed similar (similar = 1).
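For reference, cosine similarity is just the dot product of the two count vectors divided by the product of their norms. A minimal sketch that should reproduce the value above:
v1 = tem.transform([train['code1'][0]]).toarray()[0].astype(float)
v2 = tem.transform([train['code2'][0]]).toarray()[0].astype(float)
print(v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2)))  # ~0.32871913, same as cosine_similarity above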
class BaselineModel():
    def __init__(self, threshold = 0.5):
        super(BaselineModel, self).__init__()
        self.threshold = threshold
        self.vectorizer = CountVectorizer()

    def fit(self, code1, code2):
        self.vectorizer.fit(code1)
        self.vectorizer.fit(code2)  # note: sklearn's fit is not incremental, so this second call rebuilds the vocabulary from code2 only
        print('Done.')

    def predict_proba(self, code1, code2):
        code1_vecs = self.vectorizer.transform(code1)
        code2_vecs = self.vectorizer.transform(code2)
        preds = []
        for code1_vec, code2_vec in zip(code1_vecs, code2_vecs):
            preds.append(cosine_similarity(code1_vec, code2_vec))
        preds = np.reshape(preds, len(preds))
        print('Done.')
        return preds

    @discord_sender(webhook_url="https://discordapp.com/api/webhooks/9810o3fUYfVz2jWg7if")
    def predict(self, code1, code2):
        preds = self.predict_proba(code1, code2)
        preds = np.where(preds > self.threshold, 1, 0)
        return preds
We build the vocabulary from the training data and convert the test data into one-hot style count vectors.
Then cosine_similarity compares the vectors, and we output 1 if the similarity exceeds the threshold and 0 otherwise.
This is closer to brute-force counting than to deep learning, but it makes a good baseline model.
Note that decorating a function with @discord_sender(webhook_url = 'your Discord webhook URL') sends a Discord notification when the function starts and when it finishes.
For deep learning models that take a long time to train, getting that notification on your phone is a genuinely useful feature.
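A minimal sketch of that pattern on its own (the webhook URL is a placeholder to replace with your own):
from knockknock import discord_sender

@discord_sender(webhook_url="https://discordapp.com/api/webhooks/<your-webhook>")
def long_job():
    # ... long-running training code goes here ...
    return {'status': 'finished'}  # the return value is included in the completion message

long_job()  # Discord is notified when the function starts, finishes, or crashes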
model = BaselineModel(threshold = 0.4)
model.fit(train['code1'], train['code2'])
Done.
preds = model.predict(test['code1'], test['code2'])
Done.
(train['similar']).mean()
0.5011129660545354
(preds).mean()
0.5092877017250974
The mean prediction is about 0.3 with threshold = 0.5, about 0.7 with threshold = 0.3, and about 0.5 with threshold = 0.4.
Since the proportion of label 1 (similar) pairs in the training data is close to 0.5, I set the threshold to 0.4.
(The competition is scored on accuracy.)
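A more systematic way to pick the threshold would be to sweep candidate values on the training data and keep the most accurate one; a hedged sketch (not what was actually run here):
probs = model.predict_proba(train['code1'], train['code2'])
for t in [0.3, 0.4, 0.5]:
    acc = (np.where(probs > t, 1, 0) == train['similar']).mean()
    print(t, acc)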
predtrain = model.predict(train['code1'], train['code2'])
(predtrain == train['similar']).mean()
Done.
0.7125765164162493
Using a threshold of 0.4 and predicting on the training data gives an accuracy of about 0.71.
So even without any elaborate deep learning, similarity can be judged reasonably well whenever two codes use roughly the same words.
sample_submission['similar'] = preds
sample_submission.to_csv('dacon_codes.csv', index = False)
# Result: 0.688
!pip install transformers
!pip install transformers datasets
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.19.2)
Collecting datasets
Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
(dependency resolution and download logs omitted)
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.2.2 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.5.0 multidict-6.0.2 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0 yarl-1.7.2
from transformers import AutoTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset, load_metric
import torch
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = "klue/bert-base" # 'microsoft/graphcodebert-base'
MAX_LEN = 256
dataset = load_dataset('csv', data_files = path+'sample_train.csv')['train']
tokenizer = AutoTokenizer.from_pretrained(model)
Using custom data configuration default-e6c40baceae51225
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-e6c40baceae51225/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
(Hugging Face tokenizer and config download/cache logs omitted)
Model config BertConfig { "_name_or_path": "klue/bert-base", "architectures": [ "BertForMaskedLM" ], "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "transformers_version": "4.19.2", "type_vocab_size": 2, "use_cache": true, "vocab_size": 32000 }
The load_dataset function turns a csv file into a Dataset object that is easy to work with.
With AutoTokenizer.from_pretrained, you only provide the pretrained model name and the matching tokenizer is downloaded and set up automatically.
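Passing two texts to the tokenizer encodes them as a single sentence pair, with token_type_ids marking which code each token came from. A small illustration (exact ids depend on the klue/bert-base vocabulary):
enc = tokenizer("print(1)", "print(2)", max_length = MAX_LEN, truncation = True)
print(enc.keys())             # dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
print(enc['token_type_ids'])  # 0 for [CLS] + first code + [SEP], 1 for the second code + [SEP]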
def example_fn(examples):
    outputs = tokenizer(examples['code1'], examples['code2'], padding = True, max_length = MAX_LEN, truncation = True)
    if 'similar' in examples:
        outputs['labels'] = examples['similar']
    return outputs
dataset = dataset.map(example_fn, remove_columns = ['code1', 'code2', 'similar'])
dataset = dataset.train_test_split(0.1)
dataset
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 16173
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1797
    })
})
Dataset.map applies the given function to every element.
Dataset.train_test_split builds a dictionary with the data split into train and test sets.
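A tiny sketch of the same two calls on a made-up dataset (illustrative only):
from datasets import Dataset
toy_ds = Dataset.from_dict({'code1': ['print(1)', 'a = 1'], 'code2': ['print(2)', 'b = 2'], 'similar': [1, 0]})
toy_ds = toy_ds.map(example_fn, remove_columns = ['code1', 'code2', 'similar'])
print(toy_ds.train_test_split(0.5))  # DatasetDict with a 'train' and a 'test' Dataset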
_collator = DataCollatorWithPadding(tokenizer = tokenizer)  # pads every example in a batch to the longest sequence in that batch
_metric = load_metric('glue', 'sst2')  # evaluation metrics are also hosted on Hugging Face
# reference: https://huggingface.co/docs/datasets/v1.0.1/loading_metrics.html
def metric_fn(p):  # metric function
    preds, labels = p
    output = _metric.compute(references = labels, predictions = np.argmax(preds, axis = -1))
    return output
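Two quick checks of what these helpers do (toy inputs; the glue/sst2 metric is plain accuracy):
batch = _collator([tokenizer('print(1)'), tokenizer('import numpy as np')])  # pads the shorter example
print(batch['input_ids'].shape)  # (2, length of the longer sequence)
print(_metric.compute(references = [0, 1, 1], predictions = [0, 1, 0]))  # {'accuracy': 0.666...}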
model = BertForSequenceClassification.from_pretrained(model)
args = TrainingArguments(
    'runs/',
    per_device_train_batch_size = 32,
    num_train_epochs = 3,
    do_train = True,
    do_eval = True,
    save_strategy = 'epoch',
    logging_strategy = 'epoch',
    evaluation_strategy = 'epoch',
)
(klue/bert-base weight download and configuration logs omitted)
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', ...] - This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
trainer = Trainer(
    model = model,
    args = args,
    data_collator = _collator,
    train_dataset = dataset['train'],
    eval_dataset = dataset['test'],
    tokenizer = tokenizer,
    compute_metrics = metric_fn
)
@discord_sender(webhook_url="https://discordapp.com/api/webhooks/98101yo3fUYfVz2jWg7if")
def tem():
trainer.train()
tem()
***** Running training *****
Num examples = 16173
Num Epochs = 3
Instantaneous batch size per device = 32
Total train batch size (w. parallel, distributed & accumulation) = 32
Gradient Accumulation steps = 1
Total optimization steps = 1518
| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.304100 | 0.188209 | 0.920423 |
| 2 | 0.117700 | 0.130098 | 0.960490 |
| 3 | 0.036200 | 0.093612 | 0.978297 |
***** Running Evaluation *****
Num examples = 1797
Batch size = 8
Saving model checkpoint to runs/checkpoint-506
(the same evaluation and checkpoint logs repeat for runs/checkpoint-1012 and runs/checkpoint-1518)
Training completed. Do not forget to share your model on huggingface.co/models =)
test_dataset = load_dataset('csv', data_files = path+'test.csv')['train']
test_dataset = test_dataset.map(example_fn, remove_columns = ['code1', 'code2'])
predictions = trainer.predict(test_dataset)
sample_submission['similar'] = np.argmax(predictions.predictions, axis = -1)
sample_submission.to_csv('dacon_codes2.csv', index = False)
# Result: 0.787
Using custom data configuration default-6692cc772abf77e3
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-6692cc772abf77e3/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: pair_id.
***** Running Prediction *****
Num examples = 179700
Batch size = 8
predictions.predictions
array([[ 4.3174033, -3.926483 ], [-4.4421177, 4.0519753], [-4.2312655, 3.7239847], ..., [ 2.8522801, -2.7274592], [-4.4676304, 4.039054 ], [ 2.7921119, -2.535519 ]], dtype=float32)
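These are raw logits for the two classes (index 0 = not similar, index 1 = similar), and the np.argmax above picks the larger one per row. If probabilities were needed instead, a softmax would recover them:
logits = predictions.predictions
probs = np.exp(logits) / np.exp(logits).sum(axis = 1, keepdims = True)  # softmax over the two classes
print(probs[0])  # close to [1, 0]: the first pair is confidently predicted 'not similar'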
sample_submission['similar'].mean()
0.6730996104618809