[SSUDA] Hands-on with Transformers: the DACON Code Similarity Competition
• Seong Yeon Kim • 26 min read
SSUDA transformer Deep Learning natural language DACON
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Mount Google Drive so the data files can be loaded from it.
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings("ignore")
path = '/content/drive/MyDrive/coding/'
train = pd.read_csv(path + 'sample_train.csv')
test = pd.read_csv(path + 'test.csv')
sample_submission = pd.read_csv(path + 'sample_submission.csv')
train.head()
| | code1 | code2 | similar |
|---|---|---|---|
| 0 | flag = "go"\ncnt = 0\nwhile flag == "go":\n ... | # Python 3+\n#--------------------------------... | 1 |
| 1 | b, c = map(int, input().split())\n\nprint(b * c) | import numpy as np\n\nn = int(input())\na = np... | 0 |
| 2 | import numpy as np\nimport sys\nread = sys.std... | N, M = map(int, input().split())\nif M%2 != 0:... | 0 |
| 3 | b, c = map(int, input().split())\n\nprint(b * c) | n,m=map(int,input().split())\nh=list(map(int,i... | 0 |
| 4 | s=input()\nt=input()\nans=0\nfor i in range(le... | import math\na,b,h,m=map(int,input().split())\... | 0 |
Import the required packages and load the data.
!pip install transformers knockknock
from knockknock import discord_sender
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
Collecting knockknock
  Downloading knockknock-0.1.8.1-py3-none-any.whl (28 kB)
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.
Successfully installed APScheduler-3.6.3 PyJWT-2.4.0 SecretStorage-3.3.2 cachetools-4.2.2 cryptography-37.0.2 cssselect-1.1.0 cssutils-2.4.0 huggingface-hub-0.7.0 jeepney-0.8.0 keyring-23.5.1 knockknock-0.1.8.1 matrix-client-0.4.0 premailer-3.10.0 python-telegram-bot-13.12 pyyaml-6.0 tokenizers-0.12.1 tornado-6.1 transformers-4.19.2 twilio-7.9.1 yagmail-0.15.277
This installs the package that hooks into a Discord server; how to use it is explained below.
print(train.shape)
print(test.shape)
(17970, 3)
(179700, 3)
tem = CountVectorizer()
tem.fit(train['code1'])
tem.vocabulary_
{'flag': 4281,
 'go': 4603,
 'cnt': 2991,
 'while': 10034,
 'int': 5201,
 'input': 5137,
 'if': 4958,
 'stop': 9031,
 'else': 3928,
 'print': 7773,
 ...}
The CountVectorizer maps each word that appears in the input to an integer index.
tem.transform([train['code1'][0]]).toarray()
array([[0, 0, 0, ..., 0, 0, 0]])
tem.transform([train['code1'][0]]).toarray().shape
(1, 12552)
This transforms code1 into a vector over that vocabulary: an entry is nonzero only when the corresponding vocabulary word actually occurs in the code (CountVectorizer stores occurrence counts, not just 0/1).
cosine_similarity(tem.transform([train['code1'][0]]), tem.transform([train['code2'][0]]))
array([[0.32871913]])
train['similar'][0]
1
The cosine_similarity function judges similarity between the encoded word-count vectors produced above.
A score of about 0.33 suggests only partial similarity, yet the two codes are in fact a similar pair.
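To sanity-check what cosine_similarity computes, here is a minimal sketch with two hand-made count vectors (the numbers are made up for illustration):
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

u = np.array([[2, 1, 0, 1]])  # toy bag-of-words counts
v = np.array([[1, 1, 1, 0]])

# cosine similarity = (u . v) / (|u| * |v|)
manual = (u @ v.T) / (np.linalg.norm(u) * np.linalg.norm(v))
print(manual)                   # [[0.70710678]]
print(cosine_similarity(u, v))  # [[0.70710678]] -- identical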
class BaselineModel():
    def __init__(self, threshold = 0.5):
        super(BaselineModel, self).__init__()
        self.threshold = threshold
        self.vectorizer = CountVectorizer()

    def fit(self, code1, code2):
        # Fit on both columns at once: calling fit() twice would reset the
        # vectorizer and keep only the second vocabulary.
        self.vectorizer.fit(pd.concat([code1, code2]))
        print('Done.')

    def predict_proba(self, code1, code2):
        code1_vecs = self.vectorizer.transform(code1)
        code2_vecs = self.vectorizer.transform(code2)

        preds = []
        for code1_vec, code2_vec in zip(code1_vecs, code2_vecs):
            preds.append(cosine_similarity(code1_vec, code2_vec))
        preds = np.reshape(preds, len(preds))
        print('Done.')
        return preds

    @discord_sender(webhook_url="https://discordapp.com/api/webhooks/9810o3fUYfVz2jWg7if")
    def predict(self, code1, code2):
        preds = self.predict_proba(code1, code2)
        preds = np.where(preds > self.threshold, 1, 0)
        return preds
The model builds its vocabulary from the training data and converts the test data into bag-of-words (one-hot style) vectors.
cosine_similarity then scores each vector pair, and pairs scoring above the threshold are labeled 1, the rest 0.
This is less deep learning than plain brute-force counting, but that makes it a good baseline model.
As a side note, decorating a function with @discord_sender(webhook_url = '<your Discord webhook URL>') sends a Discord notification when the function starts and again when it finishes.
For deep learning models that take a long time to train, getting those alerts on your phone is a genuinely useful feature to know about.
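For reference, a minimal knockknock sketch; the webhook URL below is a placeholder to replace with your own Discord webhook:
from knockknock import discord_sender

WEBHOOK_URL = "https://discordapp.com/api/webhooks/<your-webhook-id>"  # placeholder

@discord_sender(webhook_url=WEBHOOK_URL)
def long_job():
    # ... long-running training code ...
    return 'finished'  # the return value is included in the completion message

long_job()  # posts to Discord when the call starts and again when it ends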
model = BaselineModel(threshold = 0.4)
model.fit(train['code1'], train['code2'])
Done.
preds = model.predict(test['code1'], test['code2'])
Done.
(train['similar']).mean()
0.5011129660545354
(preds).mean()
0.5092877017250974
At threshold = 0.5 the predicted positive rate was about 0.3, at threshold = 0.3 about 0.7, and at threshold = 0.4 about 0.5.
Since the fraction of similar pairs in the training data is close to 0.5, we fix the threshold at 0.4.
(The competition is scored on accuracy.)
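The sweep behind those numbers is short. A sketch that reuses the fitted baseline's raw cosine scores, so the vectors are computed only once:
probs = model.predict_proba(test['code1'], test['code2'])
for t in [0.3, 0.4, 0.5]:
    print(t, (probs > t).mean())  # fraction of pairs labeled similar at each threshold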
predtrain = model.predict(train['code1'], train['code2'])
(predtrain == train['similar']).mean()
Done.
0.7125765164162493
With the 0.4 threshold, predicting back on the training data yields an accuracy of about 0.71.
In other words, even without any grand deep learning, roughly matching word usage already produces a decent similarity judgment.
sample_submission['similar'] = preds
sample_submission.to_csv('dacon_codes.csv', index = False)
# Result: 0.688
!pip install transformers datasets
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.19.2)
Collecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.2.2 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.5.0 multidict-6.0.2 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0 yarl-1.7.2
from transformers import AutoTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset, load_metric
import torch
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = "klue/bert-base" # 'microsoft/graphcodebert-base'
MAX_LEN = 256
dataset = load_dataset('csv', data_files = path+'sample_train.csv')['train']
tokenizer = AutoTokenizer.from_pretrained(model)
Using custom data configuration default-e6c40baceae51225
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-e6c40baceae51225/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
loading configuration file https://huggingface.co/klue/bert-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fbd0b2ef898c4653902683fea8cc0dd99bf43f0e082645b913cda3b92429d1bb.99b3298ed554f2ad731c27cdb11a6215f39b90bc845ff5ce709bb4e74ba45621
Model config BertConfig {
"_name_or_path": "klue/bert-base",
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.19.2",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 32000
}
loading file https://huggingface.co/klue/bert-base/resolve/main/vocab.txt from cache at /root/.cache/huggingface/transformers/1a36e69d48a008e522b75e43693002ffc8b6e6df72de7c53412c23466ec165eb.085110015ec67fc02ad067f712a7c83aafefaf31586a3361dd800bcac635b456
loading file https://huggingface.co/klue/bert-base/resolve/main/tokenizer.json from cache at /root/.cache/huggingface/transformers/310a974e892b181d75eed58b545cc0592d066ae4ef35cc760ea92e9b0bf65b3b.74f7933572f937b11a02b2cfb4e88a024059be36c84f53241b85b1fec49e21f7
loading file https://huggingface.co/klue/bert-base/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/klue/bert-base/resolve/main/special_tokens_map.json from cache at /root/.cache/huggingface/transformers/aeaaa3afd086a040be912f92ffe7b5f85008b744624f4517c4216bcc32b51cf0.054ece8d16bd524c8a00f0e8a976c00d5de22a755ffb79e353ee2954d9289e26
loading file https://huggingface.co/klue/bert-base/resolve/main/tokenizer_config.json from cache at /root/.cache/huggingface/transformers/f8f71eb411bb03f57b455cfb1b4e04ae124201312e67a3ad66e0a92d0c228325.78871951edcb66032caa0a9628d77b3557c23616c653dacdb7a1a8f33011a843
The load_dataset function turns the csv file into a Dataset object that is easy to work with.
With AutoTokenizer.from_pretrained you only pass the name of a pretrained model, and tokenization is set up automatically.
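To make the pair encoding concrete, a small sketch of what the tokenizer returns for one toy code pair (the exact token split depends on the klue/bert-base vocabulary):
enc = tokenizer('print(1 + 1)', 'print(2 * 3)', max_length = MAX_LEN, truncation = True)
print(enc.keys())  # input_ids, token_type_ids, attention_mask
# token_type_ids distinguish the two segments: 0 for '[CLS] code1 [SEP]', 1 for 'code2 [SEP]'
print(tokenizer.decode(enc['input_ids']))  # roughly: [CLS] print ( 1 + 1 ) [SEP] print ( 2 * 3 ) [SEP]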
def example_fn(examples):
    outputs = tokenizer(examples['code1'], examples['code2'], padding = True, max_length = MAX_LEN, truncation = True)
    if 'similar' in examples:
        outputs['labels'] = examples['similar']
    return outputs
dataset = dataset.map(example_fn, remove_columns = ['code1', 'code2', 'similar'])
dataset = dataset.train_test_split(0.1)
dataset
DatasetDict({
train: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 16173
})
test: Dataset({
features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
num_rows: 1797
})
})
Dataset's map function applies the given function to each example.
Dataset's train_test_split function builds a dictionary split into train and test sets, as the toy example below shows.
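A minimal sketch of those two calls on a made-up four-row dataset:
from datasets import Dataset

toy = Dataset.from_dict({'x': [1, 2, 3, 4]})
toy = toy.map(lambda row: {'y': row['x'] * 2})  # runs the function example by example, adding a column
toy = toy.train_test_split(0.25)                # DatasetDict with 'train' and 'test' splits
print(toy['train'].num_rows, toy['test'].num_rows)  # 3 1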
_collator = DataCollatorWithPadding(tokenizer = tokenizer) # illustrated with the small example below
_metric = load_metric('glue', 'sst2') # evaluation metrics are also hosted on huggingface
# reference: https://huggingface.co/docs/datasets/v1.0.1/loading_metrics.html
def metric_fn(p): # metric function
    preds, labels = p
    output = _metric.compute(references = labels, predictions = np.argmax(preds, axis = -1))
    return output
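DataCollatorWithPadding pads each batch on the fly to the length of the longest example in that batch. A sketch with toy feature dicts:
features = [  # two tokenized examples of different lengths (toy token ids)
    {'input_ids': [2, 54, 3], 'attention_mask': [1, 1, 1], 'labels': 0},
    {'input_ids': [2, 54, 88, 99, 3], 'attention_mask': [1, 1, 1, 1, 1], 'labels': 1},
]
batch = _collator(features)
print(batch['input_ids'].shape)  # torch.Size([2, 5]) -- the shorter example is padded to length 5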
model = BertForSequenceClassification.from_pretrained(model)
args = TrainingArguments(
'runs/',
per_device_train_batch_size = 32,
num_train_epochs = 3,
do_train = True,
do_eval = True,
save_strategy = 'epoch',
logging_strategy = 'epoch',
evaluation_strategy = 'epoch',
)
loading configuration file https://huggingface.co/klue/bert-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fbd0b2ef898c4653902683fea8cc0dd99bf43f0e082645b913cda3b92429d1bb.99b3298ed554f2ad731c27cdb11a6215f39b90bc845ff5ce709bb4e74ba45621
loading weights file https://huggingface.co/klue/bert-base/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/05b36ee62545d769939a7746eca739b844a40a7a7553700f110b58b28ed6a949.7cb231256a5dbe886e12b902d05cb1241f330d8c19428508f91b2b28c1cfe0b6
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
trainer = Trainer(
model = model,
args = args,
data_collator = _collator,
train_dataset = dataset['train'],
eval_dataset = dataset['test'],
tokenizer = tokenizer,
compute_metrics = metric_fn
)
@discord_sender(webhook_url="https://discordapp.com/api/webhooks/98101yo3fUYfVz2jWg7if")
def tem():
trainer.train()
tem()
***** Running training *****
  Num examples = 16173
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 1518
| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.304100 | 0.188209 | 0.920423 |
| 2 | 0.117700 | 0.130098 | 0.960490 |
| 3 | 0.036200 | 0.093612 | 0.978297 |
***** Running Evaluation *****
  Num examples = 1797
  Batch size = 8
Saving model checkpoint to runs/checkpoint-506
***** Running Evaluation *****
  Num examples = 1797
  Batch size = 8
Saving model checkpoint to runs/checkpoint-1012
***** Running Evaluation *****
  Num examples = 1797
  Batch size = 8
Saving model checkpoint to runs/checkpoint-1518
Training completed. Do not forget to share your model on huggingface.co/models =)
test_dataset = load_dataset('csv', data_files = path+'test.csv')['train']
test_dataset = test_dataset.map(example_fn, remove_columns = ['code1', 'code2'])
predictions = trainer.predict(test_dataset)
sample_submission['similar'] = np.argmax(predictions.predictions, axis = -1)
sample_submission.to_csv('dacon_codes2.csv', index = False)
# Result: 0.787
Using custom data configuration default-6692cc772abf77e3
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-6692cc772abf77e3/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: pair_id. If pair_id are not expected by `BertForSequenceClassification.forward`, you can safely ignore this message.
***** Running Prediction *****
  Num examples = 179700
  Batch size = 8
predictions.predictions
array([[ 4.3174033, -3.926483 ],
[-4.4421177, 4.0519753],
[-4.2312655, 3.7239847],
...,
[ 2.8522801, -2.7274592],
[-4.4676304, 4.039054 ],
[ 2.7921119, -2.535519 ]], dtype=float32)
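Each row holds two logits, one per class; index 0 means not similar and index 1 means similar, matching the training labels, so the row-wise argmax above simply picks the larger logit. A softmax turns the same numbers into probabilities:
import torch

logits = torch.tensor(predictions.predictions)
probs = torch.softmax(logits, dim = -1)    # per pair: [P(not similar), P(similar)]
print(probs[0])                    # first pair: P(not similar) is near 1, so its label is 0
print(probs.argmax(dim = -1)[:3])  # tensor([0, 1, 1]) -- the same labels as np.argmax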
sample_submission['similar'].mean()
0.6730996104618809