PyMC3で階層ベイズモデル - シックスネーションズ
この記事は、PyMC3のドキュメント A Hierarchical model for Rugby prediction — PyMC3 3.0 documentation のコードを引用しています。
PyMC3でラグビーの階層ベイズモデリング
元記事では、シックスネーションズ2014の各試合を階層ベイズでモデリングしている。元記事自体も、プレミアリーグのモデリングをした記事を元ネタにしているらしい。 A Hierarchical Bayesian Model of the Premier League
シックスネーションズ
シックスネーションズ、欧州チャンピオンを決める大会。2015年ラグビーワールドカップでジャパンがスプリングボクスから大金星を挙げて、ブームと呼べるほどの注目を集めたラグビーだけど、「シックスネーションズ」の知名度はまだまだ感がある。 でもでもおそらく、クラブ、インターナショナル、数あるフォーマットの中でも、最も歴史と伝統ある大会で、実力的にも北半球最強決定戦と言って差し支えないでしょう。 開幕戦は2/4, Scotland vs Ireland。矢野武さん風に言うと、「極上の5週間が始まります!」って感じ。
さて、元記事でのシックスネーションズの説明は、
- Six Nations consists of Italy, Ireland, Scotland, England, France and Wales
- Game consists of scoring tries (similar to touch downs) or kicking the goal.
- Average player is something like 100kg and 1.82m tall.
Paul O’Connell the Irish captain is Height: 6’ 6” (1.98 m) Weight: 243 lbs (110 kg)
シックスネーションズの参加国は、イタリア、アイルランド、スコットランド、イングランド、フランス、ウェールズの6カ国
- 試合の得点は、トライかゴールキック
- プレイヤーの平均は1.82m、100kg
- ポール・オコンネルはアイルランドのキャプテンで、1.98m、110kg
名前の通り、欧州6カ国対抗戦。元々は、ホームユニオンと呼ばれるイングランド、ウェールズ、アイルランド、スコットランドで行われていた大会。初開催は1883年。フランスが1910年に参加、2000年にイタリアが参加し、現在のフォーマットになった。ちなみにポール・オコンネルは、2014年に優勝したアイルランドのキャプテン。アイリッシュ魂を象徴するような、激しく気迫のこもったプレーが持ち味のリアルロック。全盛期は一個小隊に匹敵する戦力との噂も。
さて本記事は、2016年のデータを利用するわけだから、前年王者イングランドのキャプテン、Dylan Hartley(ディラン・ハートリー)を追記しよう。2015年ラグビーワールドカップ、開催地として上位進出が期待されていたものの、オーストラリア、ウェールズと同組のいわゆる「死のプール」で2敗。ワールドカップ史上初の開催国予選敗退を経験した失意のイングランドで、キャプテンに指名されたのは、ディラン・ハートリーだった。暴言や危険なプレーを繰り返し、「悪童」と呼ばれた巨漢HOのキャプテン就任は、ラグビー界にとって大きなサプライズ。指名したのは、元日本代表監督エディー・ジョーンズだった。 2016年はグランドスラム(全勝優勝)で大会を終え、その後のテストマッチでも負け無しで14連勝中のイングランド。今年のシックスネーションズでは、優勝候補の1番手だが、同時に世界新記録のテストマッチ19連勝も見えている。
モデル化
https://pymc-devs.github.io/pymc3/notebooks/rugby_analytics.html#The-model.
前置きが長くなってしまった。 モデルの説明をざっくり書くと、
- 得点をモデル化(ポアソン分布)
- アタッキングパラメータとディフェンスパラメータを各チームごとに定義
- 本拠地での試合かどうかも考慮
上式をみると、自チームのattと敵チームのdefの和が大きいと、得点も大きくなる。 attが大きいと得点能力が高く、defが大きいと相手にたくさん得点を与えてしまうことを意味している。 モデルの説明とかは、コードにコメントしておく。詳細知りたかったら元記事読んでね。 元データは、wikipediaより。 https://en.wikipedia.org/wiki/2016_Six_Nations_Championship
import numpy as np import pandas as pd try: from StringIO import StringIO except ImportError: from io import StringIO %matplotlib inline import pymc3 as pm, theano.tensor as tt # 2016 Six Nations data_csv = StringIO("""home_team,away_team,home_score,away_score France,Italy,23,21 Scotland,England,9,15 Ireland,Wales,16,16 France,Ireland,10,9 Wales,Scotland,27,23 Italy,England,9,40 Wales,France,19,10 Italy,Scotland,20,36 England,Ireland,21,10 Ireland,Italy,58,15 England,Wales,25,21 Scotland,France,29,18 Wales,Italy,67,14 Ireland,Scotland,35,25 France,England,21,31""") # データ読み込み df = pd.read_csv(data_csv) teams = df.home_team.unique() teams = pd.DataFrame(teams, columns=['team']) teams['i'] = teams.index df = pd.merge(df, teams, left_on='home_team', right_on='team', how='left') df = df.rename(columns = {'i': 'i_home'}).drop('team', 1) df = pd.merge(df, teams, left_on='away_team', right_on='team', how='left') df = df.rename(columns = {'i': 'i_away'}).drop('team', 1) observed_home_goals = df.home_score.values observed_away_goals = df.away_score.values home_team = df.i_home.values away_team = df.i_away.values num_teams = len(df.i_home.drop_duplicates()) num_games = len(home_team) g = df.groupby('i_away') att_starting_points = np.log(g.away_score.mean()) g = df.groupby('i_home') def_starting_points = -np.log(g.away_score.mean()) model = pm.Model() with pm.Model() as model: # global model parameters # 無情報事前分布 home = pm.Normal('home', 0, tau=.0001) tau_att = pm.Gamma('tau_att', .1, .1) tau_def = pm.Gamma('tau_def', .1, .1) intercept = pm.Normal('intercept', 0, tau=.0001) # team-specific model parameters atts_star = pm.Normal("atts_star", mu=0, tau=tau_att, shape=num_teams) defs_star = pm.Normal("defs_star", mu=0, tau=tau_def, shape=num_teams) # チームごとアタック atts = pm.Deterministic('atts', atts_star - tt.mean(atts_star)) # チームごとディフェンス defs = pm.Deterministic('defs', defs_star - tt.mean(defs_star)) # ホームのθ(得点) home_theta = tt.exp(intercept + home + atts[home_team] + defs[away_team]) # アウェイのθ(得点) away_theta = tt.exp(intercept + atts[away_team] + defs[home_team]) # likelihood of observed data home_points = pm.Poisson('home_points', mu=home_theta, observed=observed_home_goals) away_points = pm.Poisson('away_points', mu=away_theta, observed=observed_away_goals) with model: # 初期値 start = pm.find_MAP() step = pm.NUTS(state=start) # サンプリング trace = pm.sample(2000, step, init=start) # グラフ描画 pm.traceplot(trace) # 最初のほう捨てたかったら # pm.traceplot(trace[500:])
サンプリング結果は、traceに保持されている。
atts、defsの各パラメータについて。まず、attsについて、これが大きいほど、得点能力が高いといえるだろう。逆に、defsは小さいほうが失点が少ないチームということを表現している。際立つイングランドのディフェンス。(イタリアぁ…)
pm.forestplot(trace, varnames=['atts'], ylabels=['France','Scotland','Ireland','Wales','Italy','England'], main="Team Offense")
pm.forestplot(trace, varnames=['defs'], ylabels=['France','Scotland','Ireland','Wales','Italy','England'], main="Team Deffence")
ちなみに、各パラメータは、varnamesとget_valuesで取得できる。
for name in trace.varnames: values = trace.get_values(name)[500:] if name in ('atts', 'defs'): for i, team_values in enumerate(values.T): print(name, teams.ix[i].values[0], np.median(team_values)) else: print(name, np.median(values)) """ >>> home 0.197413525093 tau_att_log_ 2.36765017677 tau_def_log_ 1.62635825121 intercept 2.94638413619 atts_star 0.0243648852216 defs_star -0.0579292662131 tau_att 10.6722849008 tau_def 5.08532149318 atts France -0.316409735727 atts Scotland 0.102634304541 atts Ireland 0.0401948177547 atts Wales 0.188146988173 atts Italy -0.128166189509 atts England 0.11534176357 defs France -0.0267028885955 defs Scotland 0.0636352715494 defs Ireland -0.170924989081 defs Wales -0.139824762146 defs Italy 0.679569506375 defs England -0.391210743404 """
パラメータの中央値を使った2017シックスネーションズの推定値↓ (注: あくまで2016のデータでモデリングした結果、こういう推定値が出せますというだけです。絶対にこのスコア通りになるというわけでは全然ないです。念のため。)
Scotland 21-21 Ireland England 25-9 France Italy 17-45 Wales Italy 17-39 Ireland Wales 18-18 England France 18-20 Scotland Scotland 22-24 Wales Ireland 23-11 France England 51-11 Italy Wales 23-17 Ireland Italy 19-27 France England 27-14 Scotland Scotland 50-17 Italy France 14-22 Wales Ireland 16-18 England