
Mounting Google Drive and installing packages

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Mount Google Drive so the data files can be loaded from it.

import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import warnings
warnings.filterwarnings("ignore")

path = '/content/drive/MyDrive/coding/'

train = pd.read_csv(path + 'sample_train.csv')
test = pd.read_csv(path + 'test.csv')
sample_submission = pd.read_csv(path + 'sample_submission.csv')
train.head()
code1 code2 similar
0 flag = "go"\ncnt = 0\nwhile flag == "go":\n ... # Python 3+\n#--------------------------------... 1
1 b, c = map(int, input().split())\n\nprint(b * c) import numpy as np\n\nn = int(input())\na = np... 0
2 import numpy as np\nimport sys\nread = sys.std... N, M = map(int, input().split())\nif M%2 != 0:... 0
3 b, c = map(int, input().split())\n\nprint(b * c) n,m=map(int,input().split())\nh=list(map(int,i... 0
4 s=input()\nt=input()\nans=0\nfor i in range(le... import math\na,b,h,m=map(int,input().split())\... 0

Import the required packages and load the train, test, and sample-submission data.

!pip install transformers knockknock

from knockknock import discord_sender
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
     |████████████████████████████████| 4.2 MB 5.3 MB/s 
Collecting knockknock
  Downloading knockknock-0.1.8.1-py3-none-any.whl (28 kB)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.3)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
     |████████████████████████████████| 6.6 MB 48.3 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
     |████████████████████████████████| 596 kB 47.0 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)
     |████████████████████████████████| 86 kB 6.1 MB/s 
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.2.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)
Collecting keyring
  Downloading keyring-23.5.1-py3-none-any.whl (33 kB)
Collecting yagmail>=0.11.214
  Downloading yagmail-0.15.277-py2.py3-none-any.whl (17 kB)
Collecting twilio
  Downloading twilio-7.9.1-py2.py3-none-any.whl (1.4 MB)
     |████████████████████████████████| 1.4 MB 45.6 MB/s 
Collecting matrix-client
  Downloading matrix_client-0.4.0-py2.py3-none-any.whl (43 kB)
     |████████████████████████████████| 43 kB 2.4 MB/s 
Collecting python-telegram-bot
  Downloading python_telegram_bot-13.12-py3-none-any.whl (511 kB)
     |████████████████████████████████| 511 kB 70.2 MB/s 
Collecting premailer
  Downloading premailer-3.10.0-py2.py3-none-any.whl (19 kB)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)
Collecting jeepney>=0.4.2
  Downloading jeepney-0.8.0-py3-none-any.whl (48 kB)
     |████████████████████████████████| 48 kB 6.2 MB/s 
Collecting SecretStorage>=3.2
  Downloading SecretStorage-3.3.2-py3-none-any.whl (15 kB)
Collecting cryptography>=2.0
  Downloading cryptography-37.0.2-cp36-abi3-manylinux_2_24_x86_64.whl (4.0 MB)
     |████████████████████████████████| 4.0 MB 45.9 MB/s 
Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.7/dist-packages (from cryptography>=2.0->SecretStorage>=3.2->keyring->knockknock) (1.15.0)
Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.12->cryptography>=2.0->SecretStorage>=3.2->keyring->knockknock) (2.21)
Requirement already satisfied: urllib3~=1.21 in /usr/local/lib/python3.7/dist-packages (from matrix-client->knockknock) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.5.18.1)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: lxml in /usr/local/lib/python3.7/dist-packages (from premailer->yagmail>=0.11.214->knockknock) (4.2.6)
Requirement already satisfied: cachetools in /usr/local/lib/python3.7/dist-packages (from premailer->yagmail>=0.11.214->knockknock) (4.2.4)
Collecting cssutils
  Downloading cssutils-2.4.0-py3-none-any.whl (404 kB)
     |████████████████████████████████| 404 kB 68.7 MB/s 
Collecting cssselect
  Downloading cssselect-1.1.0-py2.py3-none-any.whl (16 kB)
Collecting cachetools
  Downloading cachetools-4.2.2-py3-none-any.whl (11 kB)
Collecting tornado>=6.1
  Downloading tornado-6.1-cp37-cp37m-manylinux2010_x86_64.whl (428 kB)
     |████████████████████████████████| 428 kB 23.9 MB/s 
Requirement already satisfied: pytz>=2018.6 in /usr/local/lib/python3.7/dist-packages (from python-telegram-bot->knockknock) (2022.1)
Collecting APScheduler==3.6.3
  Downloading APScheduler-3.6.3-py2.py3-none-any.whl (58 kB)
     |████████████████████████████████| 58 kB 7.3 MB/s 
Requirement already satisfied: tzlocal>=1.2 in /usr/local/lib/python3.7/dist-packages (from APScheduler==3.6.3->python-telegram-bot->knockknock) (1.5.1)
Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from APScheduler==3.6.3->python-telegram-bot->knockknock) (1.15.0)
Requirement already satisfied: setuptools>=0.7 in /usr/local/lib/python3.7/dist-packages (from APScheduler==3.6.3->python-telegram-bot->knockknock) (57.4.0)
Collecting PyJWT<3.0.0,>=2.0.0
  Downloading PyJWT-2.4.0-py3-none-any.whl (18 kB)
Installing collected packages: jeepney, cssutils, cssselect, cryptography, cachetools, tornado, SecretStorage, pyyaml, PyJWT, premailer, APScheduler, yagmail, twilio, tokenizers, python-telegram-bot, matrix-client, keyring, huggingface-hub, transformers, knockknock
  Attempting uninstall: cachetools
    Found existing installation: cachetools 4.2.4
    Uninstalling cachetools-4.2.4:
      Successfully uninstalled cachetools-4.2.4
  Attempting uninstall: tornado
    Found existing installation: tornado 5.1.1
    Uninstalling tornado-5.1.1:
      Successfully uninstalled tornado-5.1.1
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.
Successfully installed APScheduler-3.6.3 PyJWT-2.4.0 SecretStorage-3.3.2 cachetools-4.2.2 cryptography-37.0.2 cssselect-1.1.0 cssutils-2.4.0 huggingface-hub-0.7.0 jeepney-0.8.0 keyring-23.5.1 knockknock-0.1.8.1 matrix-client-0.4.0 premailer-3.10.0 python-telegram-bot-13.12 pyyaml-6.0 tokenizers-0.12.1 tornado-6.1 transformers-4.19.2 twilio-7.9.1 yagmail-0.15.277

This installs transformers and knockknock, a package that sends notifications to a Discord server. Its usage is explained below.

print(train.shape)
print(test.shape)
(17970, 3)
(179700, 3)

Bag-of-Words (BoW) approach

tem = CountVectorizer()
tem.fit(train['code1'])
tem.vocabulary_
{'flag': 4281,
 'go': 4603,
 'cnt': 2991,
 'while': 10034,
 'int': 5201,
 'input': 5137,
 'if': 4958,
 'stop': 9031,
 'else': 3928,
 'print': 7773,
 'case': 2833,
 'str': 9035,
 'map': 6152,
 'split': 8907,
 'import': 5023,
 'numpy': 7117,
 'as': 2180,
 'np': 7004,
 'sys': 9222,
 'read': 8108,
 'stdin': 9007,
 'buffer': 2639,
 'readline': 8129,
 'from': 4363,
 'numba': 7090,
 'njit': 6906,
 'def': 3485,
 'getinputs': 4546,
 'cs': 3257,
 'array': 2158,
 'int32': 5203,
 '26': 725,
 'reshape': 8212,
 'return': 8247,
 'i8': 4903,
 'i4': 4898,
 'cache': 2753,
 'true': 9668,
 '_compute_score1': 1572,
 'out': 7356,
 'score': 8599,
 'last': 5725,
 'zeros': 10428,
 'for': 4327,
 'in': 5026,
 'range': 8075,
 'len': 5793,
 'sum': 9115,
 'step1': 9011,
 'max_score': 6252,
 '10000000': 225,
 'best_i': 2428,
 'append': 2103,
 'pop': 7622,
 'output': 7362,
 'join': 5476,
 'astype': 2199,
 'tolist': 9583,
 'ans': 2035,
 'min': 6387,
 'abs': 1835,
 'rev': 8254,
 'false': 4161,
 'list': 5884,
 'mini': 6435,
 '10': 219,
 'count': 3194,
 'se': 8622,
 'set': 8700,
 'tot': 9595,
 'add': 1875,
 'yes': 10329,
 'no': 6941,
 'setrecursionlimit': 8719,
 'week': 10019,
 'sun': 9180,
 'mon': 6544,
 'tue': 9686,
 'wed': 10018,
 'thu': 9464,
 'fri': 4357,
 'sat': 8572,
 'main': 6121,
 'decode': 3478,
 '__name__': 1532,
 '__main__': 1529,
 '10000': 222,
 'break': 2604,
 'and': 2025,
 'exit': 4053,
 'power': 7673,
 'mod': 6498,
 'bi': 2449,
 'format': 4331,
 '2進数': 846,
 'res': 8203,
 'reverse': 8258,
 'temp1': 9363,
 'temp2': 9364,
 'temp': 9361,
 'x_': 10138,
 'a1': 1640,
 'a2': 1643,
 'a3': 1648,
 'an': 2020,
 'at': 2203,
 'となるt': 10732,
 'sのペアの数は': 9233,
 'かつ': 10484,
 'sのペアの数と等しく': 9232,
 'sのペアの数と等しい': 9231,
 'ms': 6577,
 'ps': 7830,
 'msc': 6579,
 'psc': 7833,
 'mi': 6368,
 'pi': 7542,
 'get': 4475,
 'ii': 4973,
 'li': 5827,
 'd2': 3378,
 'dict': 3572,
 'dp0': 3737,
 'dp1': 3738,
 'dp2': 3739,
 'dp3': 3740,
 'max': 6202,
 'math': 6185,
 'string': 9059,
 'itertools': 5377,
 'fractions': 4351,
 'heapq': 4778,
 'collections': 3052,
 're': 8105,
 'bisect': 2487,
 'random': 8064,
 'time': 9473,
 'copy': 3175,
 'functools': 4395,
 'deque': 3524,
 'inf': 5079,
 '20': 635,
 '998244353': 1449,
 'dr': 3774,
 'dc': 3445,
 'li_': 5830,
 'lf': 5819,
 'float': 4303,
 'ls': 6024,
 'dp': 3736,
 'reduce': 8149,
 'gcd': 4445,
 'la': 5708,
 'inv': 5238,
 'pow': 7661,
 'lcm': 5753,
 'addmod': 1886,
 'answer': 2075,
 'sums': 9176,
 'try': 9673,
 'except': 4034,
 'eoferror': 3976,
 'a_mod': 1713,
 'hn': 4825,
 'strip': 9065,
 'mapint': 6158,
 'accumulate': 1855,
 'open': 7313,
 'a_acc': 1665,
 'initial': 5107,
 'min_diff': 6395,
 'left': 5777,
 'right': 8278,
 'となるものがいくつあるか': 10734,
 'subdp': 9086,
 'target': 9318,
 'h1': 4689,
 'm1': 6078,
 'h2': 4691,
 'm2': 6079,
 '60': 1150,
 'usr': 9819,
 'bin': 2460,
 'env': 3971,
 'python3': 7876,
 'chain': 2869,
 'solve': 8852,
 'current': 3317,
 '一度も通ったことがない': 11203,
 'ステップ目に通った事を記録': 11090,
 'loop_len': 5978,
 'ループの長さ': 11191,
 'rest': 8219,
 '残り長さ': 12102,
 '残り長さをループの余剰にする': 12103,
 'tokens': 9581,
 'line': 5857,
 'next': 6834,
 'type': 9718,
 'counter': 3219,
 'decimal': 3475,
 'numbers': 7105,
 'book': 2565,
 'bit全探索': 2511,
 'xls': 10199,
 'cost': 3185,
 'x1': 10120,
 'y1': 10287,
 'int64': 5204,
 'u4': 9732,
 'uint32': 9759,
 'argv': 2149,
 'online_judge': 7299,
 'pycc': 7865,
 'cc': 2843,
 'types': 9720,
 'my_module': 6626,
 'export': 4061,
 'factorization': 4145,
 'n_max': 6703,
 'sqrt': 8926,
 'sort': 8866,
 'p_max': 7437,
 'primes_num': 7770,
 'shape': 8734,
 'a_start': 1732,
 'check': 2895,
 'stack': 8961,
 'empty': 3932,
 'p_stack': 7449,
 'compile': 3123,
 'in_file': 5031,
 'fromstring': 4367,
 'sep': 8690,
 'pairwise': 7464,
 'coprime': 3174,
 'setwise': 8725,
 'not': 6974,
 'dev': 3536,
 'product': 7822,
 'repeat': 8191,
 'lr': 6015,
 'csum': 3267,
 'encoding': 3937,
 'utf': 9828,
 'bisect_left': 2488,
 'これで二部探索の大小検索が行える': 10580,
 '最小公倍数などはこっち': 11966,
 '10進数で考慮できる': 308,
 '再帰回数上限はでdefault1000': 11465,
 'abssort': 1839,
 'sorted': 8873,
 'key': 5575,
 'lambda': 5719,
 'tmps': 9554,
 'a_abs': 1664,
 'tmp_1': 9504,
 'deepcopy': 3483,
 'tmp_2': 9507,
 'tmp_2_': 9508,
 'tmp_1_': 9505,
 'を一つ消す': 10962,
 'remove': 8176,
 'を一つたす': 10961,
 'がなかった時': 10508,
 'を消して': 10994,
 'を追加': 11011,
 'tmp_1_m': 9506,
 'tmp_2_m': 9509,
 'pass': 7492,
 'elif': 3924,
 '正負に関係なくsort': 12092,
 'a_p': 1718,
 'plusを入れる': 7584,
 'a_n': 1714,
 'を入れる': 10973,
 'ok': 7267,
 '正の数が存在している': 12084,
 '選択肢がない時': 12440,
 '負の数が偶数個': 12367,
 '奇数個選ぶ': 11713,
 'position': 7646,
 'pairs': 7463,
 'cnt_p': 3011,
 'cnt_n': 3010,
 'move': 6566,
 'enumerate': 3969,
 '_update_score': 1625,
 '_random_update': 1610,
 'randint': 8063,
 'new_score': 6819,
 '_random_swap': 1609,
 'delta': 3504,
 'd1': 3375,
 'or': 7325,
 'step2': 9012,
 '48': 1011,
 'rand': 8058,
 '13': 376,
 'prime_numbers': 7765,
 'n以下の素数列挙': 7202,
 'eratosthenes': 3991,
 'prime_list': 7761,
 'prime_factorization': 7754,
 'factors': 4152,
 'tmp_n': 9529,
 'ceil': 2856,
 'count_divide': 3205,
 'divmod': 3691,
 'a_b_max': 1675,
 'a_b_min': 1676,
 'max_factors': 6224,
 'min_factors': 6398,
 'a_s': 1726,
 'a_r': 1722,
 'b_s': 2332,
 'b_r': 2330,
 'unsafe': 9783,
 'safe': 8551,
 'tmp': 9500,
 'raw_input': 8087,
 'linked': 5874,
 'defaultdict': 3487,
 'i2group': 4896,
 'gid': 4583,
 'get_root': 4514,
 'get_groups': 4491,
 'ra': 8040,
 'rb': 8090,
 'continue': 3162,
 'n_connected': 6685,
 'n_group': 6694,
 'values': 9876,
 'clear': 2951,
 'zip': 10438,
 'iter': 5371,
 'class': 2948,
 'facts': 4154,
 'max_num': 6242,
 '__init__': 1525,
 'self': 8682,
 'fact': 4128,
 'power_func': 7674,
 'comb': 3077,
 'log': 5954,
 'r26': 7999,
 'r25': 7998,
 '25': 717,
 'total': 9601,
 's_n': 8531,
 'end': 3938,
 'is': 5282,
 'dice': 3558,
 'ns': 7018,
 'ew': 4020,
 'question': 7973,
 'top': 9587,
 'front': 4368,
 'settop': 8724,
 'sides': 8770,
 'index': 5060,
 'tail': 9291,
 'insert': 5182,
 'dnum': 3711,
 'readlines': 8130,
 'lower': 6004,
 'rstrip': 8427,
 'word': 10080,
 'abcdefghijklmnopqrstuvwxyz': 1823,
 'insertionsort': 5187,
 'lst': 6035,
 'pajew': 7465,
 '1000000': 224,
 'find': 4229,
 'unite': 9778,
 'grou': 4646,
 'arr': 2155,
 'gro': 4644,
 'gro_no': 4645,
 'popleft': 7635,
 'distance': 3649,
 'returns': 8252,
 'minkowski': 6440,
 'of': 7253,
 'vactor': 9862,
 'chebyshev': 2894,
 '6f': 1202,
 '000000': 3,
 '449490': 991,
 '154435': 437,
 'run': 8437,
 'dim': 3606,
 'flake8': 4295,
 'noqa': 6967,
 'building_a': 2649,
 '11': 309,
 'building_b': 2650,
 'building_c': 2651,
 'building_d': 2652,
 'stdout': 9009,
 'write': 10104,
 'p_list': 7432,
 'c_list': 2738,
 'score1': 8600,
 'score2': 8601,
 'div': 3667,
 'getn': 4554,
 'getnm': 4555,
 'getlist': 4552,
 'getarray': 4534,
 'intn': 5227,
 'rand_n': 8061,
 'ran1': 8056,
 'ran2': 8057,
 'rand_list': 8060,
 'rantime': 8079,
 'rand_ints_nodup': 8059,
 'rand_query': 8062,
 'r_query': 8036,
 'n_q': 6709,
 'combinations': 3089,
 'permutations': 7529,
 'operator': 7318,
 'mul': 6589,
 'bisect_right': 2492,
 '1000000000': 227,
 'code': 3028,
 'limit回までコストカットできる': 5852,
 'knapsack_6': 5603,
 'upper': 9796,
 'limit': 5851,
 'weight': 10021,
 'value': 9869,
 'ボーナスでコスト1にするのを使ったか': 11148,
 'コストカットできる時': 11056,
 'できない時': 10685,
 '1000': 221,
 'back': 2344,
 'namedtuple': 6725,
 'uf': 9753,
 'rank': 8078,
 'size': 8801,
 'root': 8359,
 'same': 8565,
 'friends': 4359,
 'block': 2537,
 '153_b': 433,
 'a0': 1631,
 '102': 264,
 'log2': 5956,
 'logn': 5966,
 'db': 3438,
 'dbs': 3441,
 'now': 6981,
 'dll': 3698,
 'command': 3108,
 'appendleft': 2107,
 'delete': 3497,
 'deletefirst': 3499,
 'coding': 3033,
 'sr': 8934,
 'ir': 5276,
 '左からgreedyに': 11744,
 'monsters': 6549,
 'bomb': 2556,
 'attack': 2219,
 'cook': 3169,
 'your': 10360,
 'dish': 3632,
 'here': 4790,
 '400': 956,
 '599': 1115,
 '600': 1151,
 '799': 1279,
 '800': 1312,
 '999': 1450,
 '1199': 338,
 '1200': 345,
 '1399': 389,
 '1400': 391,
 '1599': 444,
 '1600': 446,
 '1799': 480,
 '1800': 484,
 '1999': 520,
 'merge': 6348,
 'mid': 6376,
 'global': 4594,
 'n1': 6664,
 'n2': 6669,
 'mergesort': 6352,
 'e_red_scarf': 3852,
 'mask': 6166,
 '1e9': 539,
 'coefs': 3036,
 '14': 390,
 '22': 690,
 '33': 882,
 '46': 999,
 '15': 421,
 'hon': 4834,
 'pon': 7617,
 'bon': 2563,
 'num': 7039,
 'ca': 2752,
 'val': 9863,
 'items': 5370,
 'sa': 8547,
 'ng': 6878,
 '方針': 11892,
 '各文字列の出現回数を数え': 11628,
 '出現回数が最大なる文字列を昇順に出力する': 11485,
 'リスト': 11168,
 'は辞書型のサブクラスであり': 10900,
 'キーに要素': 11033,
 '値に出現回数という形式': 11381,
 'most_common': 6554,
 '要素': 12305,
 '出現回数': 11481,
 'というタプルを出現回数順に並べたリスト': 10713,
 'max_count': 6215,
 '最大の出現回数': 11945,
 '出現回数が最も多い単語を集計する': 11484,
 '昇順にソートして出力': 11903,
 'resolve': 8216,
 '300000': 854,
 '200000': 638,
 '100000': 223,
 'bubble_sort_aoj': 2629,
 'nums': 7120,
 'バブルソート': 11123,
 '隣接項の比較': 12482,
 'fibo': 4199,
 'result': 8225,
 'groupby': 4659,
 '100': 220,
 '101': 258,
 '解説と': 12331,
 '13355391': 380,
 'を参考に実装予定': 10981,
 'lonlieness': 5971,
 'ab': 1757,
 'bad_a': 2347,
 'bad_b': 2348,
 '_gcd': 1582,
 'setdefault': 8717,
 '仲の悪いグループも登録しておく': 11334,
 'pair': 7459,
 'keys': 5584,
 '仲の悪いグループは隣り合っているので飛び石で計算': 11333,
 'gourp1': 4617,
 'から1匹以上選ぶパターン': 10488,
 'group2': 4650,
 'どちらからも選ばないパターン計算する': 10758,
 'group1': 4649,
 'は仲が悪いので同時に選ばれることはない': 10891,
 'group_num': 4657,
 'badgroup_num': 2350,
 '全員と仲が悪いイワシのパターンを足し': 11446,
 'すべてのイワシを選ばないパターンを除外': 10608,
 'mn': 6490,
 'diff': 3589,
 '500': 1050,
 'isupper': 5358,
 'sum1': 9116,
 'sum2': 9117,
 'del': 3496,
 'shellsort': 8744,
 '262913': 727,
 '65921': 1178,
 '16577': 456,
 '4193': 974,
 '1073': 286,
 '281': 740,
 '77': 1266,
 '23': 700,
 'n以下が確定していて': 7198,
 '0以外の数をk個使ったとき': 211,
 'n以下が確定していないときの0以外の数の個数': 7199,
 '0を使うことで0以外の数が増えないパターン': 204,
 '0以外の数を使うことで0以外の数が増えるパターン': 212,
 '今回でn以下が確定するパターン': 11301,
 '確定する前までに0以外の数を何個使っているか': 12177,
 '今回でn以下が確定することはない': 11300,
 'すでにk個以上の0以外の数を使っているとき': 10602,
 'ちょうどk個使っている時': 10670,
 '0を使うしかない': 205,
 'n以下を確定させるためaは使えない': 7203,
 'にぶたん': 10800,
 'n人のメンバーそれぞれが完食にかかる時間のうち最大値をxに以下にできるか': 7195,
 '12': 343,
 'need_training': 6782,
 'cond': 3138,
 'rotate': 8374,
 'deck': 3477,
 'nnnn': 6938,
 'query': 7971,
 'num_of_sug': 7072,
 'sug': 9104,
 'tuple': 9688,
 'liar': 5834,
 'honest': 4836,
 'sug_tmp': 9106,
 'sug_': 9105,
 'hi': 4796,
 'hihi': 4805,
 'hihihi': 4806,
 'hihihihi': 4807,
 'hihihihihi': 4808,
 'statistics': 8999,
 'amed': 2015,
 'median': 6328,
 'bmed': 2547,
 'sgn': 8731,
 'pfugou': 7536,
 '選んでないのが2個以下': 12437,
 'よってlen': 10939,
 '前述で処理済み': 11546,
 'なのでここでやることはない': 10778,
 'b_num': 2328,
 'b_best': 2309,
 'a_num': 1717,
 'maxmize': 6282,
 'none': 6964,
 'all_max': 1977,
 'st': 8953,
 'scores': 8612,
 'num_elem': 7058,
 'all_sum': 1980,
 'max_': 6205,
 'max_r': 6249,
 'temp_r': 9379,
 'temp_max': 9377,
 'chr': 2929,
 'ord': 7327,
 'head': 4767,
 's_temp': 8544,
 'dig_0_index': 3594,
 'dig_1_index': 3595,
 'dig_2_index': 3596,
 '_input': 1588,
 'wa': 9980,
 'ac': 1841,
 'str_l': 9049,
 'int_l': 5210,
 'pp': 7680,
 'seikai': 8665,
 'matigai': 6186,
 'correct': 3181,
 'mistake': 6473,
 'io': 5259,
 'stringio': 9063,
 'kuku': 5627,
 's_set': 8541,
 'ansl': 2068,
 'mat_sum': 6177,
 'xrange': 10224,
 'color': 3056,
 'dfs': 3540,
 'mydict': 6640,
 'answer_list': 2077,
 'wd': 10011,
 'ck': 2946,
 'pe': 7513,
 'penalty': 7521,
 'cp': 3235,
 'r_map': 8032,
 'r_list': 8031,
 '最大公約数': 11957,
 '最小公倍数': 11965,
 'gcd_num': 4452,
 'lcm_num': 5762,
 'die': 3583,
 'pips': 7550,
 'move_die': 6569,
 'direction': 3617,
 'get_upside': 4530,
 'init_die': 5093,
 'pip': 7549,
 'roll_die': 8333,
 'directions': 3618,
 'maxs': 6287,
 'mins': 6447,
 'offset': 7257,
 '1000000007': 236,
 'matrix': 6189,
 'cv': 3342,
 'fv': 4399,
 'simu': 8784,
 'pro': 7809,
 'end_time': 3942,
 'bool': 2569,
 'name': 6724,
 'inds': 5078,
 'ds': 3779,
 'xy': 10236,
 'この2行でメモリアクセス省略しないとtleになる': 10573,
 'nds': 6770,
 'c1': 2717,
 'c2': 2718,
 '164': 452,
 '1415926535898': 403,
 '08': 104,
 'point': 7603,
 'sharp': 8735,
 'enumerate_divisors': 3970,
 'all_divisors': 1972,
 'divisor': 3685,
 'calculate_reminder': 2779,
 'reminder': 8174,
 'sorted_lst': 8880,
 'qs': 7951,
 'rs': 8414,
 'sect': 8640,
 'lstrip': 6040,
 '水たまりを結合': 12110,
 'sum_l': 9136,
 'room': 8349,
 '_s': 1611,
 'bitsum': 2507,
 '_bit': 1567,
 'bitadd': 2501,
 'al': 1940,
 'al_to_idx': 1942,
 'init': 5091,
 'n_': 6678,
 'bit': 2494,
 'idx': 4946,
 '_query': 1607,
 'decrement': 3482,
 'old': 7277,
 'increment': 5052,
 '_ans': 1564,
 'money': 6546,
 'inputs': 5175,
 'ss': 8944,
 'deg': 3490,
 '30': 850,
 '180': 483,
 '360': 908,
 'radians': 8049,
 'cos': 3182,
 'sin': 8785,
 '110000': 312,
 'kk': 5595,
 'lu': 6049,
 '1e18': 537,
 '2019': 660,
 '相対速度': 12168,
 '距離': 12384,
 'eval': 4011,
 'train': 9638,
 '回数': 11669,
 'nlogn': 6919,
 '時間': 11919,
 'nlog': 6918,
 'maxk': 6280,
 'kaisuu': 5551,
 'get_theta': 4523,
 'm_angle': 6092,
 'h_angle': 4696,
 'calculate_vector_distance': 2781,
 'theta': 9451,
 'dictionary': 3579,
 'input_num': 5160,
 'lim': 5849,
 '200004': 645,
 'bin_sum': 2465,
 'bin_sum2': 2466,
 'pop_num': 7627,
 '200005': 646,
 '整数': 11865,
 '整数複数個': 11871,
 '改行区切り': 11846,
 'スペース区切り': 11092,
 'の行列': 10869,
 'abc': 1764,
 'table': 9285,
 'have': 4753,
 'check_p': 2902,
 'ws': 10108,
 'ct': 3268,
 'nt': 7028,
 'can_eat': 2791,
 'f_time': 4112,
 'training': 9640,
 'high': 4798,
 'low': 6003,
 'casefold': 2834,
 'end_of_text': 3940,
 'syo': 9218,
 'amari': 2011,
 'dic': 3553,
 'kaisu': 5550,
 'long': 5969,
 'ae': 1903,
 'bs_meguru': 2615,
 'isok': 5345,
 'to': 9563,
 '59': 1112,
 'shell': 8742,
 'leng': 5803,
 'lists': 5910,
 'eg': 3893,
 'xs': 10225,
 'seen': 8652,
 'init_cmb': 5092,
 'nmax': 6925,
 '出力の制限': 11478,
 'g1': 4410,
 '元テーブル': 11397,
 'g2': 4411,
 '逆元テーブル': 12404,
 'inverse': 5247,
 '逆元テーブル計算用テーブル': 12405,
 'cmb': 2972,
 'modn': 6527,
 'ci': 2932,
 'である個数が': 10683,
 'となるような数列の数は': 10735,
 'ncm': 6752,
 'hm': 4820,
 '1cn': 531,
 'で足算する': 10707,
 'wk': 10059,
 'hw': 4864,
 'cnt_h': 3008,
 'cnt_w': 3015,
 'max_h': 6226,
 'max_w': 6262,
 'h_list': 4704,
 'w_list': 9972,
 'find_primes': 4240,
 'rn': 8324,
 'prev': 7720,
 'pos': 7639,
 'alp': 1989,
 'atoi': 2218,
 'insort_left': 5191,
 'tle': 9491,
 'dtype': 3789,
 'listをsortする': 5915,
 'a0cen': 1636,
 'b0cen': 2284,
 'nn': 6932,
 'a0cen1': 1637,
 'b0cen1': 2285,
 'ei': 3901,
 'mx': 6610,
 'su': 9077,
 'graph': 4627,
 'numofedges': 7114,
 'visited': 9908,
 'edges': 3883,
 'col': 3041,
 'adj': 1889,
 'maxcolor': 6270,
 'ans_1': 2040,
 'ans_2': 2041,
 'amax': 2014,
 'n0': 6661,
 '1以上となる最小の2のべき乗数': 598,
 'afre': 1907,
 'パワーの頻度': 11130,
 'は切り捨てなので': 10892,
 'rintで四捨五入してから': 8298,
 'rint': 8297,
 'fft': 4191,
 'irfft': 5278,
 'rfft': 8267,
 'scum': 8618,
 'cumsum': 3290,
 '累積和': 12229,
 'bd': 2393,
 '上からm個を取り出したい': 11223,
 'searchsorted': 8629,
 '価値iを生み出せる組みがm個以上ある': 11367,
 '価値iが生み出せる選び方の余分なものを引きたい': 11366,
 'ret': 8235,
 'numberofcards': 7103,
 'far': 4162,
 'kyu': 5643,
 'dist': 3638,
 'vec': 9890,
 'morau': 6551,
 'factinv': 4136,
 'solver': 8863,
 'bombs': 2559,
 'maxx': 6299,
 'maxy': 6300,
 'gcd1': 4446,
 'cmath': 2971,
 'inp': 5125,
 'nm': 6923,
 'heapify': 4773,
 'heappop': 4775,
 'heappush': 4776,
 '17': 464,
 'day': 3432,
 'data': 3419,
 '全探索なら': 11448,
 '4000': 957,
 'bit全探索でok': 2512,
 '一文字ずつlistへ格納': 11206,
 'most': 6553,
 '縦の全loop': 12261,
 'aa': 1740,
 'where': 10032,
 'b1': 2286,
 'b2': 2288,
 'b3': 2289,
 'alphabet': 1996,
 'loop': 5972,
 'i1': 4893,
 'i2': 4895,
 '96': 1430,
 '1100': 311,
 'get_dist': 4485,
 'du': 3791,
 'dv': 3799,
 'fullmatch': 4390,
 'exame': 4030,
 'suma': 9165,
 '連想配列': 12413,
 '先頭からの番号': 11405,
 '余分な量': 11352,
 'que': 7965,
 'k番目以降は一番左のやつ消していく': 5677,
 'cur': 3295,
 'examf': 4031,
 'si': 8765,
 'fact_inv': 4131,
 'getdivisor': 4539,
 'sum_leaf': 9137,
 'before_top': 2410,
 '29': 754,
 'iim': 4980,
 'p25': 7397,
 'p26': 7398,
 'p26inv': 7399,
 '576923081': 1106,
 'elem': 3919,
 'dq': 3770,
 'order': 7330,
 'deletelast': 3501,
 'koch': 5609,
 'start': 8972,
 '途中の頂点をa': 12408,
 'cとする': 3366,
 'rr': 8410,
 'segmenttree': 8662,
 '非再帰': 12491,
 'segment': 8661,
 'tree': 9646,
 'func': 4392,
 '配列の長さ': 12448,
 'minだとrmqになる': 6463,
 '木の高さhとすると': 12002,
 '1までのノード数': 587,
 'h段目のノードにアクセスするために使う': 4871,
 'ノード': 11119,
 'parent': 7475,
 'child': 2914,
 '1とk': 567,
 'bit_length': 2498,
 'あたいの初期化': 10470,
 'build': 2641,
 'setの後に一斉更新': 8727,
 'reversed': 8262,
 'update': 9791,
 'aに更新する': 2267,
 '更新ぶんをrootまで更新': 11927,
 'のfuncを求める': 10818,
 'queries': 7970,
 'a2n': 1646,
 'createinp': 3250,
 'seg': 8658,
 'terms': 9396,
 '51': 1073,
 'goukei': 4616,
 'route': 8386,
 'obs': 7238,
 'length': 5804,
 'alpha2num': 1995,
 'alpha': 1994,
 'item': 5366,
 'num2alpha': 7044,
 '64': 1169,
 '90': 1397,
 'ap': 2098,
 'bust': 2664,
 'win': 10047,
 'graph_input': 4630,
 'friend': 4358,
 'group': 4648,
 'で頂点': 10709,
 'がどの': 10507,
 'に属するかを記録していく': 10809,
 '後に': 11782,
 'に対して': 10807,
 ...}
tem.transform([train['code1'][0]]).toarray()
array([[0, 0, 0, ..., 0, 0, 0]])
tem.transform([train['code1'][0]]).toarray().shape
(1, 12552)

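Each snippet is turned into a vector over this vocabulary (12,552 words, hence the shape above): every entry is the number of times that word appears in the snippet, and words outside the vocabulary are ignored. Note that CountVectorizer produces counts, not 0/1 values; it only behaves like a one-hot (binary) encoding if you pass binary=True.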
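To make the transformation concrete, here is a small sketch on two made-up snippets (not from the competition data). It shows that CountVectorizer stores counts, and that its default token pattern only keeps tokens of two or more word characters, which is why single-letter variable names such as a, b or n never show up in the vocabulary above.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Two tiny, made-up snippets (not from the competition data)
snippets = [
    "a, b = map(int, input().split())\nprint(a * b)",
    "n = int(input())\nprint(n * n)\nprint(n)",
]

vec = CountVectorizer()          # default token_pattern keeps tokens of 2+ word characters
X = vec.fit_transform(snippets)  # sparse matrix of word counts, one row per snippet

print(sorted(vec.vocabulary_))   # ['input', 'int', 'map', 'print', 'split']
print(X.toarray())               # [[1 1 1 1 1]
                                 #  [1 1 0 2 0]]   <- 'print' appears twice

# binary=True records only presence (0/1) instead of counts
X_bin = CountVectorizer(binary=True).fit_transform(snippets)
print(X_bin.toarray())           # [[1 1 1 1 1]
                                 #  [1 1 0 1 0]]
```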

cosine_similarity(tem.transform([train['code1'][0]]), tem.transform([train['code2'][0]]))
array([[0.32871913]])
train['similar'][0]
1

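cosine_similarity compares the two count vectors built above and returns the cosine of the angle between them: 1 means the two snippets use vocabulary words in exactly the same proportions, and 0 means they share no vocabulary words at all.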

The score here is only about 0.33, which suggests weak similarity, yet the two snippets are actually labeled as similar (similar = 1).

Reference: https://wikidocs.net/24603
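The cosine similarity of two vectors a and b is dot(a, b) / (norm(a) * norm(b)). A quick check with toy count vectors (made-up values) that computing this by hand matches scikit-learn's cosine_similarity; because word counts are never negative, the score for count vectors always falls between 0 and 1.

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Count vectors over a shared 5-word vocabulary (toy values)
a = np.array([1, 1, 1, 1, 1])
b = np.array([1, 1, 0, 2, 0])

# cos(a, b) = (a . b) / (||a|| * ||b||) = 4 / (sqrt(5) * sqrt(6))
manual = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# scikit-learn expects 2-D inputs (n_samples, n_features) and returns a matrix
sk = cosine_similarity(a.reshape(1, -1), b.reshape(1, -1))[0, 0]

print(round(manual, 4), round(sk, 4))  # 0.7303 0.7303
```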

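class BaselineModel():
    def __init__(self, threshold = 0.5):
        self.threshold = threshold
        self.vectorizer = CountVectorizer()

    def fit(self, code1, code2):
        # Build a single vocabulary from both columns.
        # (Calling fit() twice would throw away the vocabulary learned from code1.)
        self.vectorizer.fit(pd.concat([code1, code2]))
        print('Done.')
    
    def predict_proba(self, code1, code2):
        # Turn each snippet into a word-count vector over the shared vocabulary
        code1_vecs = self.vectorizer.transform(code1)
        code2_vecs = self.vectorizer.transform(code2)

        preds = []

        # Cosine similarity of each (code1, code2) pair, one row at a time
        for code1_vec, code2_vec in zip(code1_vecs, code2_vecs):
            preds.append(cosine_similarity(code1_vec, code2_vec))
        
        # Each result is a 1x1 array; flatten them into a 1-D array of scores
        preds = np.reshape(preds, len(preds))
        print('Done.')

        return preds
    
    # knockknock sends a Discord notification when predict() starts and finishes
    @discord_sender(webhook_url="https://discordapp.com/api/webhooks/9810o3fUYfVz2jWg7if")
    def predict(self, code1, code2):
        preds = self.predict_proba(code1, code2)
        # Similarity above the threshold -> 1 (similar), otherwise 0
        preds = np.where(preds > self.threshold, 1, 0)

        return preds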
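One detail worth checking in fit: calling CountVectorizer.fit a second time rebuilds the vocabulary from scratch rather than extending it, so fitting on code1 and then on code2 keeps only code2's words. A minimal demonstration with made-up snippets:

```python
from sklearn.feature_extraction.text import CountVectorizer

code1 = ["import numpy as np"]
code2 = ["print(input())"]

vec = CountVectorizer()
vec.fit(code1)
vec.fit(code2)  # refits from scratch: the vocabulary learned from code1 is discarded
print(sorted(vec.vocabulary_))       # ['input', 'print']

# Fitting once on both lists keeps every word
vec_both = CountVectorizer().fit(code1 + code2)
print(sorted(vec_both.vocabulary_))  # ['as', 'import', 'input', 'np', 'numpy', 'print']
```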

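The model builds a vocabulary from the training code, then converts each test snippet into a word-count vector over that vocabulary.

Next, cosine_similarity measures how similar each pair of vectors is, and the model outputs 1 if the similarity is greater than the threshold and 0 otherwise.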
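predict_proba calls cosine_similarity once per pair, which gets slow for the 179,700 test pairs. A vectorized alternative, sketched below under the assumption that both inputs are sparse count matrices produced by the same fitted vectorizer, L2-normalizes every row and takes the row-wise dot product in one pass. rowwise_cosine is a hypothetical helper (not part of scikit-learn), and the snippets are made up.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize


def rowwise_cosine(X1, X2):
    """Cosine similarity between row i of X1 and row i of X2, for every i.

    X1 and X2 are sparse count matrices with the same shape.
    """
    # L2-normalize every row; all-zero rows stay zero, so their similarity is 0
    X1n = normalize(X1)
    X2n = normalize(X2)
    # Element-wise product, then sum each row = dot product of the matching rows
    return np.asarray(X1n.multiply(X2n).sum(axis=1)).ravel()


# Toy pairs (made up); the last pair contains no 2+ character tokens at all
code1 = ["a, b = map(int, input().split())", "import numpy as np", "x = y"]
code2 = ["n = int(input())", "print(input())", "y = x"]

vec = CountVectorizer().fit(code1 + code2)
X1, X2 = vec.transform(code1), vec.transform(code2)

fast = rowwise_cosine(X1, X2)
# Same values as the per-row loop used in predict_proba
slow = np.array([cosine_similarity(r1, r2)[0, 0] for r1, r2 in zip(X1, X2)])

threshold = 0.4
preds = np.where(fast > threshold, 1, 0)
print(preds)  # [1 0 0]
```

Both versions give the same scores; the vectorized one just avoids a Python loop over every pair.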

This is closer to simple brute-force word counting than to deep learning, which is exactly what makes it a good baseline.

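Tip: if you put @discord_sender(webhook_url='<Discord webhook URL>') above a function, knockknock sends a Discord notification when the function starts and when it finishes (or crashes).

For deep learning models that take a long time to train, this lets you get the alert on your phone, so it is a genuinely useful feature to know.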
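Roughly speaking, such a decorator wraps the function and sends a message before the call, after it returns, and if it raises. The sketch below is a hypothetical, network-free imitation of that behavior (notify_start_end is not knockknock's actual code); send can be any function that takes a string. To post to Discord instead, send could be something like lambda msg: requests.post(url, json={"content": msg}), assuming url is your own webhook URL.

```python
import functools
import time


def notify_start_end(send):
    """Hypothetical decorator factory imitating start/finish/crash notifications.

    `send` is any callable that accepts a message string.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            send(f"{func.__name__} started")
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                send(f"{func.__name__} crashed: {exc!r}")
                raise  # re-raise so the caller still sees the error
            send(f"{func.__name__} finished in {time.time() - start:.1f}s")
            return result  # the wrapped function's return value is passed through
        return wrapper
    return decorator


# Example: collect the messages in a list instead of posting them anywhere
messages = []

@notify_start_end(messages.append)
def train_model(n):
    return sum(range(n))

result = train_model(10)
print(result)       # 45
print(messages[0])  # train_model started
```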

model = BaselineModel(threshold = 0.4)
model.fit(train['code1'], train['code2'])
Done.
preds = model.predict(test['code1'], test['code2'])
Done.
(train['similar']).mean()
0.5011129660545354
(preds).mean()
0.5092877017250974

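Checking the mean prediction (the fraction of pairs predicted as similar) for several thresholds gave about 0.3 at threshold = 0.5, about 0.7 at threshold = 0.3, and about 0.5 at threshold = 0.4.

Since the fraction of training pairs labeled similar = 1 is close to 0.5 (0.501), I'll set the threshold to 0.4.

(The competition is scored by accuracy.)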
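Matching the positive rate is a heuristic. Since the metric is accuracy, an alternative is to try candidate thresholds on labeled data and keep the one with the highest accuracy. The sketch below does that; best_threshold is a hypothetical helper and the scores are toy values. Tuning and measuring on the same training data gives an optimistic estimate, so a held-out validation split would be safer.

```python
import numpy as np


def best_threshold(scores, labels, candidates=None):
    """Return (threshold, accuracy) maximizing accuracy of `scores > threshold`."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    if candidates is None:
        candidates = np.unique(scores)  # every observed score is a candidate cut
    best_t, best_acc = None, -1.0
    for t in candidates:
        # Same rule as predict(): similarity above the threshold -> 1
        acc = np.mean(np.where(scores > t, 1, 0) == labels)
        if acc > best_acc:
            best_t, best_acc = float(t), float(acc)
    return best_t, best_acc


# Toy similarity scores and labels (made up)
print(best_threshold([0.1, 0.35, 0.45, 0.8], [0, 0, 1, 1]))  # (0.35, 1.0)
```

With the model above, the scores could come from model.predict_proba(train['code1'], train['code2']) and the labels from train['similar'].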

predtrain = model.predict(train['code1'], train['code2'])
(predtrain == train['similar']).mean()
Done.
0.7125765164162493

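Predicting the training data with the 0.4 threshold gives an accuracy of about 0.71.

So even without heavyweight deep learning, simply checking whether two snippets use roughly the same words is enough to judge similarity reasonably well.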

sample_submission['similar'] = preds
sample_submission.to_csv('dacon_codes.csv', index = False)
# Result: 0.688

Using Hugging Face tools

!pip install transformers
!pip install transformers datasets
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.19.2)
Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.7.0)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.0)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)
Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.12.1)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.3)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.2.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.5.18.1)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.19.2)
Collecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
     |████████████████████████████████| 346 kB 4.3 MB/s 
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.3)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)
Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.7.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.0)
Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.12.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.7.0)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (4.2.0)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)
Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
     |████████████████████████████████| 212 kB 10.1 MB/s 
Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
     |████████████████████████████████| 140 kB 10.9 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
     |████████████████████████████████| 1.1 MB 12.0 MB/s 
Collecting dill<0.3.5
  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
     |████████████████████████████████| 86 kB 6.1 MB/s 
Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.12.2)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.5.18.1)
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1
  Downloading urllib3-1.25.11-py2.py3-none-any.whl (127 kB)
     |████████████████████████████████| 127 kB 28.0 MB/s 
Collecting aiosignal>=1.1.2
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting asynctest==0.13.0
  Downloading asynctest-0.13.0-py3-none-any.whl (26 kB)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (21.4.0)
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)
     |████████████████████████████████| 271 kB 28.3 MB/s 
Collecting async-timeout<5.0,>=4.0.0a3
  Downloading async_timeout-4.0.2-py3-none-any.whl (5.8 kB)
Collecting multidict<7.0,>=4.5
  Downloading multidict-6.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (94 kB)
     |████████████████████████████████| 94 kB 1.8 MB/s 
Collecting frozenlist>=1.1.1
  Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)
     |████████████████████████████████| 144 kB 29.7 MB/s 
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.0.12)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.8.0)
Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)
Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.1)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)
Installing collected packages: multidict, frozenlist, yarl, urllib3, asynctest, async-timeout, aiosignal, fsspec, dill, aiohttp, xxhash, responses, datasets
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.24.3
    Uninstalling urllib3-1.24.3:
      Successfully uninstalled urllib3-1.24.3
  Attempting uninstall: dill
    Found existing installation: dill 0.3.5.1
    Uninstalling dill-0.3.5.1:
      Successfully uninstalled dill-0.3.5.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires tornado~=5.1.0; python_version >= "3.0", but you have tornado 6.1 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 datasets-2.2.2 dill-0.3.4 frozenlist-1.3.0 fsspec-2022.5.0 multidict-6.0.2 responses-0.18.0 urllib3-1.25.11 xxhash-3.0.0 yarl-1.7.2
from transformers import AutoTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset, load_metric
import torch

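# Manual device selection is not needed: Trainer places the model on the GPU automatically when one is available
#device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = "klue/bert-base" # alternative: 'microsoft/graphcodebert-base', which is pretrained on source code
MAX_LEN = 256 # maximum number of tokens per (code1, code2) pair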
dataset = load_dataset('csv', data_files = path+'sample_train.csv')['train']
tokenizer = AutoTokenizer.from_pretrained(model)
Using custom data configuration default-e6c40baceae51225
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-e6c40baceae51225/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
https://huggingface.co/klue/bert-base/resolve/main/tokenizer_config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpjiwfrzjb
storing https://huggingface.co/klue/bert-base/resolve/main/tokenizer_config.json in cache at /root/.cache/huggingface/transformers/f8f71eb411bb03f57b455cfb1b4e04ae124201312e67a3ad66e0a92d0c228325.78871951edcb66032caa0a9628d77b3557c23616c653dacdb7a1a8f33011a843
creating metadata file for /root/.cache/huggingface/transformers/f8f71eb411bb03f57b455cfb1b4e04ae124201312e67a3ad66e0a92d0c228325.78871951edcb66032caa0a9628d77b3557c23616c653dacdb7a1a8f33011a843
https://huggingface.co/klue/bert-base/resolve/main/config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmp3hhezuli
storing https://huggingface.co/klue/bert-base/resolve/main/config.json in cache at /root/.cache/huggingface/transformers/fbd0b2ef898c4653902683fea8cc0dd99bf43f0e082645b913cda3b92429d1bb.99b3298ed554f2ad731c27cdb11a6215f39b90bc845ff5ce709bb4e74ba45621
creating metadata file for /root/.cache/huggingface/transformers/fbd0b2ef898c4653902683fea8cc0dd99bf43f0e082645b913cda3b92429d1bb.99b3298ed554f2ad731c27cdb11a6215f39b90bc845ff5ce709bb4e74ba45621
loading configuration file https://huggingface.co/klue/bert-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fbd0b2ef898c4653902683fea8cc0dd99bf43f0e082645b913cda3b92429d1bb.99b3298ed554f2ad731c27cdb11a6215f39b90bc845ff5ce709bb4e74ba45621
Model config BertConfig {
  "_name_or_path": "klue/bert-base",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.19.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32000
}

https://huggingface.co/klue/bert-base/resolve/main/vocab.txt not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpf76t5zqa
storing https://huggingface.co/klue/bert-base/resolve/main/vocab.txt in cache at /root/.cache/huggingface/transformers/1a36e69d48a008e522b75e43693002ffc8b6e6df72de7c53412c23466ec165eb.085110015ec67fc02ad067f712a7c83aafefaf31586a3361dd800bcac635b456
creating metadata file for /root/.cache/huggingface/transformers/1a36e69d48a008e522b75e43693002ffc8b6e6df72de7c53412c23466ec165eb.085110015ec67fc02ad067f712a7c83aafefaf31586a3361dd800bcac635b456
https://huggingface.co/klue/bert-base/resolve/main/tokenizer.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmp5jkm4tpg
storing https://huggingface.co/klue/bert-base/resolve/main/tokenizer.json in cache at /root/.cache/huggingface/transformers/310a974e892b181d75eed58b545cc0592d066ae4ef35cc760ea92e9b0bf65b3b.74f7933572f937b11a02b2cfb4e88a024059be36c84f53241b85b1fec49e21f7
creating metadata file for /root/.cache/huggingface/transformers/310a974e892b181d75eed58b545cc0592d066ae4ef35cc760ea92e9b0bf65b3b.74f7933572f937b11a02b2cfb4e88a024059be36c84f53241b85b1fec49e21f7
https://huggingface.co/klue/bert-base/resolve/main/special_tokens_map.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmp34fn1p8b
storing https://huggingface.co/klue/bert-base/resolve/main/special_tokens_map.json in cache at /root/.cache/huggingface/transformers/aeaaa3afd086a040be912f92ffe7b5f85008b744624f4517c4216bcc32b51cf0.054ece8d16bd524c8a00f0e8a976c00d5de22a755ffb79e353ee2954d9289e26
creating metadata file for /root/.cache/huggingface/transformers/aeaaa3afd086a040be912f92ffe7b5f85008b744624f4517c4216bcc32b51cf0.054ece8d16bd524c8a00f0e8a976c00d5de22a755ffb79e353ee2954d9289e26
loading file https://huggingface.co/klue/bert-base/resolve/main/vocab.txt from cache at /root/.cache/huggingface/transformers/1a36e69d48a008e522b75e43693002ffc8b6e6df72de7c53412c23466ec165eb.085110015ec67fc02ad067f712a7c83aafefaf31586a3361dd800bcac635b456
loading file https://huggingface.co/klue/bert-base/resolve/main/tokenizer.json from cache at /root/.cache/huggingface/transformers/310a974e892b181d75eed58b545cc0592d066ae4ef35cc760ea92e9b0bf65b3b.74f7933572f937b11a02b2cfb4e88a024059be36c84f53241b85b1fec49e21f7
loading file https://huggingface.co/klue/bert-base/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/klue/bert-base/resolve/main/special_tokens_map.json from cache at /root/.cache/huggingface/transformers/aeaaa3afd086a040be912f92ffe7b5f85008b744624f4517c4216bcc32b51cf0.054ece8d16bd524c8a00f0e8a976c00d5de22a755ffb79e353ee2954d9289e26
loading file https://huggingface.co/klue/bert-base/resolve/main/tokenizer_config.json from cache at /root/.cache/huggingface/transformers/f8f71eb411bb03f57b455cfb1b4e04ae124201312e67a3ad66e0a92d0c228325.78871951edcb66032caa0a9628d77b3557c23616c653dacdb7a1a8f33011a843

load_dataset function: converts a CSV file into a dataset object that is easy to process.
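As a rough picture of what the loaded rows look like, here is a standard-library sketch that reads a CSV with the same columns as sample_train.csv into one dict per row. It is only an analogy (the datasets library actually stores the data in an Apache Arrow table), and the CSV_TEXT sample and the read_pairs helper are hypothetical.

```python
import csv
import io

# Hypothetical CSV with the same columns as sample_train.csv.
# The \n inside the quoted field becomes a real line break, like the multi-line code in the data.
CSV_TEXT = '''code1,code2,similar
"print(1)","print(2)",1
"x = input()\nprint(x)","print(input())",0
'''

def read_pairs(text):
    """Return one dict per CSV row, casting the 'similar' label to int."""
    rows = []
    for row in csv.DictReader(io.StringIO(text)):
        row["similar"] = int(row["similar"])
        rows.append(row)
    return rows

rows = read_pairs(CSV_TEXT)
print(len(rows))  # 2
print(rows[0])
```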

AutoTokenizer.from_pretrained: pass only the name of a pretrained model, and the tokenizer that matches that model is loaded automatically.
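Called with two inputs, as in tokenizer(examples['code1'], examples['code2'], ...) in the next cell, a BERT tokenizer builds a single sentence-pair sequence. The sketch below is a simplified, hypothetical re-implementation on already-split tokens (a real tokenizer produces WordPiece subwords mapped to integer ids). It shows the layout of input_ids, token_type_ids and attention_mask, and how truncation = True (the "longest_first" strategy) trims tokens from the longer snippet first.

```python
def encode_pair(tokens_a, tokens_b, max_len):
    """Build a BERT-style pair input: [CLS] A [SEP] B [SEP].

    Simplified "longest_first" truncation: repeatedly drop the last token of
    the longer segment until both segments plus 3 special tokens fit max_len.
    """
    if max_len < 3:
        raise ValueError("max_len must leave room for [CLS] and two [SEP] tokens")
    a, b = list(tokens_a), list(tokens_b)
    while len(a) + len(b) + 3 > max_len:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    tokens = ["[CLS]"] + a + ["[SEP]"] + b + ["[SEP]"]
    # Segment 0: [CLS], A and the first [SEP]; segment 1: B and the final [SEP].
    token_type_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    attention_mask = [1] * len(tokens)  # every real token is attended to
    return tokens, token_type_ids, attention_mask

tokens, token_type_ids, attention_mask = encode_pair(
    ["a", "=", "input", "(", ")"], ["print", "(", "a", ")"], max_len=8
)
print(tokens)          # ['[CLS]', 'a', '=', 'input', '[SEP]', 'print', '(', '[SEP]']
print(token_type_ids)  # [0, 0, 0, 0, 0, 1, 1, 1]
```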

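def example_fn(examples):
    # Encode both snippets as one sentence pair: [CLS] code1 [SEP] code2 [SEP].
    # truncation = True trims the pair to MAX_LEN tokens. Because map() is not batched here,
    # padding = True has no effect; padding to a common length is done by the data collator.
    outputs = tokenizer(examples['code1'], examples['code2'], padding = True, max_length = MAX_LEN, truncation = True)

    # test.csv has no 'similar' column, so labels are only added for the training data
    if 'similar' in examples:
        outputs['labels'] = examples['similar']

    return outputs


# Tokenize every row and drop the raw text and label columns
dataset = dataset.map(example_fn, remove_columns = ['code1', 'code2', 'similar'])

# Hold out 10% of the rows as a validation set
dataset = dataset.train_test_split(test_size = 0.1)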
dataset
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 16173
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1797
    })
})

Dataset map function: applies the given function to each element (row) of the dataset.
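A minimal plain-Python analogue of the non-batched map call used above, assuming rows are dicts; map_rows and toy_fn are hypothetical names. It follows the documented datasets behaviour that remove_columns are dropped before the function's outputs are merged in, so the function still receives the original columns, and a column it re-adds is kept.

```python
def map_rows(rows, fn, remove_columns=()):
    """Apply fn to every row dict and return the new rows."""
    new_rows = []
    for row in rows:
        # Drop removed columns first, then add fn's outputs; fn still sees the full row.
        kept = {k: v for k, v in row.items() if k not in remove_columns}
        kept.update(fn(row))
        new_rows.append(kept)
    return new_rows

def toy_fn(row):
    """Hypothetical stand-in for example_fn: derives a feature and renames the label."""
    return {"n_chars": len(row["code1"]) + len(row["code2"]), "labels": row["similar"]}

rows = [{"code1": "ab", "code2": "c", "similar": 1}]
print(map_rows(rows, toy_fn, remove_columns=["code1", "code2", "similar"]))
# [{'n_chars': 3, 'labels': 1}]
```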

Dataset train_test_split function: splits the dataset into training and test sets and returns them as a dictionary (a DatasetDict).
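The split can be sketched as a shuffle followed by slicing. The hypothetical split_rows helper below sizes the test part as ceil(test_size × n), the scikit-learn-style rule; for the 17,970 rows implied by the output above (16,173 + 1,797) it reproduces the same sizes. A fixed seed makes the shuffle reproducible; the real train_test_split also accepts a seed argument.

```python
import math
import random

def split_rows(rows, test_size, seed=None):
    """Shuffle row indices, then put ceil(test_size * n) rows in 'test' and the rest in 'train'."""
    n_test = math.ceil(test_size * len(rows))
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return {
        "train": [rows[i] for i in train_idx],
        "test": [rows[i] for i in test_idx],
    }

split = split_rows(list(range(17970)), test_size=0.1, seed=0)
print(len(split["train"]), len(split["test"]))  # 16173 1797
```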

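_collator = DataCollatorWithPadding(tokenizer = tokenizer) # pads each batch to its longest sequence (explained in the image below)
_metric = load_metric('glue', 'sst2') # metrics are also available from Hugging Face; the GLUE SST-2 metric reports accuracy
# Reference: https://huggingface.co/docs/datasets/v1.0.1/loading_metrics.html

def metric_fn(p): # metric function that Trainer calls at each evaluation
    preds, labels = p
    # Convert logits to predicted class ids, then compare them with the reference labels
    output = _metric.compute(references = labels, predictions = np.argmax(preds, axis = -1))
    return output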

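# Loads the checkpoint named by the `model` string and replaces it with the model object;
# the new classification head has 2 labels by default (not similar / similar)
model = BertForSequenceClassification.from_pretrained(model)

args = TrainingArguments(
    'runs/',                          # output_dir: checkpoints are written here
    per_device_train_batch_size = 32,
    num_train_epochs = 3,
    do_train = True,
    do_eval = True,
    save_strategy = 'epoch',          # save, log and evaluate once per epoch
    logging_strategy = 'epoch',
    evaluation_strategy = 'epoch',
)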
loading configuration file https://huggingface.co/klue/bert-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/fbd0b2ef898c4653902683fea8cc0dd99bf43f0e082645b913cda3b92429d1bb.99b3298ed554f2ad731c27cdb11a6215f39b90bc845ff5ce709bb4e74ba45621
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.19.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32000
}

https://huggingface.co/klue/bert-base/resolve/main/pytorch_model.bin not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpb8lhvjfk
storing https://huggingface.co/klue/bert-base/resolve/main/pytorch_model.bin in cache at /root/.cache/huggingface/transformers/05b36ee62545d769939a7746eca739b844a40a7a7553700f110b58b28ed6a949.7cb231256a5dbe886e12b902d05cb1241f330d8c19428508f91b2b28c1cfe0b6
creating metadata file for /root/.cache/huggingface/transformers/05b36ee62545d769939a7746eca739b844a40a7a7553700f110b58b28ed6a949.7cb231256a5dbe886e12b902d05cb1241f330d8c19428508f91b2b28c1cfe0b6
loading weights file https://huggingface.co/klue/bert-base/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/05b36ee62545d769939a7746eca739b844a40a7a7553700f110b58b28ed6a949.7cb231256a5dbe886e12b902d05cb1241f330d8c19428508f91b2b28c1cfe0b6
Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).

[Image: explanation of DataCollatorWithPadding]
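DataCollatorWithPadding pads each batch only up to the longest sequence in that batch (dynamic padding), so example_fn does not need to pad every pair to MAX_LEN. The sketch below imitates this on plain lists: input_ids are filled with pad_token_id 0 from the config printed above, attention_mask with 0 so padded positions are ignored, and token_type_ids with 0. The real collator returns PyTorch tensors; pad_batch and the feature values are hypothetical.

```python
PAD_TOKEN_ID = 0  # "pad_token_id": 0 in the klue/bert-base config above

def pad_batch(features, pad_id=PAD_TOKEN_ID):
    """Right-pad every feature to the longest input_ids in this batch only."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "token_type_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        gap = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * gap)
        batch["token_type_ids"].append(f["token_type_ids"] + [0] * gap)
        batch["attention_mask"].append(f["attention_mask"] + [0] * gap)  # 0 = ignore this position
        batch["labels"].append(f["labels"])  # labels are not padded
    return batch

# Two hypothetical tokenized pairs of different lengths (the ids are arbitrary).
features = [
    {"input_ids": [5, 6, 7], "token_type_ids": [0, 0, 1], "attention_mask": [1, 1, 1], "labels": 1},
    {"input_ids": [5, 6, 7, 8, 9], "token_type_ids": [0, 0, 0, 1, 1], "attention_mask": [1] * 5, "labels": 0},
]
batch = pad_batch(features)
print(batch["input_ids"])       # [[5, 6, 7, 0, 0], [5, 6, 7, 8, 9]]
print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```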

trainer = Trainer(
    model = model,
    args = args,
    data_collator = _collator,
    train_dataset = dataset['train'],
    eval_dataset = dataset['test'],
    tokenizer = tokenizer,
    compute_metrics = metric_fn
)

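# knockknock sends a Discord message through the webhook when training starts and when it ends (or crashes)
@discord_sender(webhook_url="https://discordapp.com/api/webhooks/98101yo3fUYfVz2jWg7if")
def train_with_notification():
    trainer.train()

train_with_notification()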
***** Running training *****
  Num examples = 16173
  Num Epochs = 3
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 1518
[1518/1518 39:04, Epoch 3/3]
Epoch  Training Loss  Validation Loss  Accuracy
1      0.304100       0.188209         0.920423
2      0.117700       0.130098         0.960490
3      0.036200       0.093612         0.978297


***** Running Evaluation *****
  Num examples = 1797
  Batch size = 8
Saving model checkpoint to runs/checkpoint-506
Configuration saved in runs/checkpoint-506/config.json
Model weights saved in runs/checkpoint-506/pytorch_model.bin
tokenizer config file saved in runs/checkpoint-506/tokenizer_config.json
Special tokens file saved in runs/checkpoint-506/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1797
  Batch size = 8
Saving model checkpoint to runs/checkpoint-1012
Configuration saved in runs/checkpoint-1012/config.json
Model weights saved in runs/checkpoint-1012/pytorch_model.bin
tokenizer config file saved in runs/checkpoint-1012/tokenizer_config.json
Special tokens file saved in runs/checkpoint-1012/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1797
  Batch size = 8
Saving model checkpoint to runs/checkpoint-1518
Configuration saved in runs/checkpoint-1518/config.json
Model weights saved in runs/checkpoint-1518/pytorch_model.bin
tokenizer config file saved in runs/checkpoint-1518/tokenizer_config.json
Special tokens file saved in runs/checkpoint-1518/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)
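The step counts in the training log follow from simple arithmetic: with 16,173 training rows and a batch size of 32, one epoch takes ceil(16173 / 32) = 506 optimization steps, so 3 epochs give 1,518 steps, and save_strategy = 'epoch' writes checkpoint-506, checkpoint-1012 and checkpoint-1518. A small sketch with hypothetical helper names, assuming gradient accumulation of 1 and a data loader that keeps the last partial batch (the Trainer default).

```python
import math

def batches(num_examples, batch_size):
    """Number of batches when the last, partial batch is kept."""
    return math.ceil(num_examples / batch_size)

def optimization_steps(num_examples, batch_size, epochs):
    """Return (steps per epoch, total steps), assuming no gradient accumulation."""
    per_epoch = batches(num_examples, batch_size)
    return per_epoch, per_epoch * epochs

per_epoch, total = optimization_steps(16173, 32, 3)
print(per_epoch, total)                      # 506 1518
print([per_epoch * e for e in range(1, 4)])  # [506, 1012, 1518] -> checkpoint numbers
print(batches(179700, 8))                    # 22463
```

The same ceiling gives the 22,463 prediction steps reported further down for 179,700 test pairs at the default evaluation batch size of 8.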


test_dataset = load_dataset('csv', data_files = path+'test.csv')['train']
test_dataset = test_dataset.map(example_fn, remove_columns = ['code1', 'code2'])

predictions = trainer.predict(test_dataset)

sample_submission['similar'] = np.argmax(predictions.predictions, axis = -1)
sample_submission.to_csv('dacon_codes2.csv', index = False)
# Result: 0.787
Using custom data configuration default-6692cc772abf77e3
Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-6692cc772abf77e3/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)
The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: pair_id. If pair_id are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 179700
  Batch size = 8
[22463/22463 49:25]
predictions.predictions
array([[ 4.3174033, -3.926483 ],
       [-4.4421177,  4.0519753],
       [-4.2312655,  3.7239847],
       ...,
       [ 2.8522801, -2.7274592],
       [-4.4676304,  4.039054 ],
       [ 2.7921119, -2.535519 ]], dtype=float32)
sample_submission['similar'].mean()
0.6730996104618809
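Each row of predictions.predictions holds two logits (class 0 = not similar, class 1 = similar), and np.argmax(..., axis = -1) picks the larger one. The sketch below redoes this in plain Python on the six rows visible in the truncated output above, adds a softmax to turn logits into probabilities, and computes accuracy the way metric_fn's metric reports it (the fraction of predictions equal to the labels). The labels passed to accuracy are hypothetical.

```python
import math

# The six logit rows visible in the truncated predictions.predictions output above.
logits = [
    [4.3174033, -3.926483],
    [-4.4421177, 4.0519753],
    [-4.2312655, 3.7239847],
    [2.8522801, -2.7274592],
    [-4.4676304, 4.039054],
    [2.7921119, -2.535519],
]

def argmax(row):
    """Index of the largest value (first one on ties, like np.argmax)."""
    return max(range(len(row)), key=lambda i: row[i])

def softmax(row):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def accuracy(preds, labels):
    """Fraction of predictions that equal the labels."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

preds = [argmax(r) for r in logits]
print(preds)                                   # [0, 1, 1, 0, 1, 0]
print(sum(preds) / len(preds))                 # 0.5 (share predicted "similar" in these rows)
print(accuracy(preds, [0, 1, 1, 0, 0, 0]))     # hypothetical labels: 5 of 6 correct
```

Over all 179,700 test pairs, the mean of the predicted labels is the 0.673 printed above, i.e. about 67% of the pairs are predicted to be similar.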