忘れないようにメモっとく

機械学習とかプログラミングとか。

PyMC3で階層ベイズモデル - シックスネーションズ

この記事は、PyMC3のドキュメント A Hierarchical model for Rugby prediction — PyMC3 3.0 documentation のコードを引用しています。

PyMC3でラグビーの階層ベイズモデリング

元記事では、シックスネーションズ2014の各試合を階層ベイズモデリングしている。元記事自体も、プレミアリーグモデリングをした記事を元ネタにしているらしい。 A Hierarchical Bayesian Model of the Premier League

シックスネーションズ

シックスネーションズ、欧州チャンピオンを決める大会。2015年ラグビーワールドカップでジャパンがスプリングボクスから大金星を挙げて、ブームと呼べるほどの注目を集めたラグビーだけど、「シックスネーションズ」の知名度はまだまだ感がある。 でもでもおそらく、クラブ、インターナショナル、数あるフォーマットの中でも、最も歴史と伝統ある大会で、実力的にも北半球最強決定戦と言って差し支えないでしょう。 開幕戦は2/4, Scotland vs Ireland。矢野武さん風に言うと、「極上の5週間が始まります!」って感じ。

さて、元記事でのシックスネーションズの説明は、

名前の通り、欧州6カ国対抗戦。元々は、ホームユニオンと呼ばれるイングランドウェールズアイルランドスコットランドで行われていた大会。初開催は1883年。フランスが1910年に参加、2000年にイタリアが参加し、現在のフォーマットになった。ちなみにポール・オコンネルは、2014年に優勝したアイルランドのキャプテン。アイリッシュ魂を象徴するような、激しく気迫のこもったプレーが持ち味のリアルロック。全盛期は一個小隊に匹敵する戦力との噂も。 http://i.telegraph.co.uk/multimedia/archive/03571/paul_oconnell_3571133b.jpg

さて本記事は、2016年のデータを利用するわけだから、前年王者イングランドのキャプテン、Dylan Hartley(ディラン・ハートリー)を追記しよう。2015年ラグビーワールドカップ、開催地として上位進出が期待されていたものの、オーストラリア、ウェールズと同組のいわゆる「死のプール」で2敗。ワールドカップ史上初の開催国予選敗退を経験した失意のイングランドで、キャプテンに指名されたのは、ディラン・ハートリーだった。暴言や危険なプレーを繰り返し、「悪童」と呼ばれた巨漢HOのキャプテン就任は、ラグビー界にとって大きなサプライズ。指名したのは、元日本代表監督エディー・ジョーンズだった。 2016年はグランドスラム(全勝優勝)で大会を終え、その後のテストマッチでも負け無しで14連勝中のイングランド。今年のシックスネーションズでは、優勝候補の1番手だが、同時に世界新記録のテストマッチ19連勝も見えている。

http://i3.mirror.co.uk/incoming/article4633256.ece/ALTERNATES/s615/Dylan-Hartley.jpg

モデル化

https://pymc-devs.github.io/pymc3/notebooks/rugby_analytics.html#The-model.

前置きが長くなってしまった。 モデルの説明をざっくり書くと、

  • 得点をモデル化(ポアソン分布)
  • アタッキングパラメータとディフェンスパラメータを各チームごとに定義
  • 本拠地での試合かどうかも考慮
 log \theta_{g1} = home + att_{h(g)} + def_{a(g)}
 log \theta_{g2} = att_{a(g)} + def_{h(g)}

上式をみると、自チームのattと敵チームのdefの和が大きいと、得点も大きくなる。 attが大きいと得点能力が高く、defが大きいと相手にたくさん得点を与えてしまうことを意味している。 モデルの説明とかは、コードにコメントしておく。詳細知りたかったら元記事読んでね。 元データは、wikipediaより。 https://en.wikipedia.org/wiki/2016_Six_Nations_Championship

import numpy as np
import pandas as pd
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
%matplotlib inline
import pymc3 as pm, theano.tensor as tt

# 2016 Six Nations
data_csv = StringIO("""home_team,away_team,home_score,away_score
France,Italy,23,21
Scotland,England,9,15
Ireland,Wales,16,16
France,Ireland,10,9
Wales,Scotland,27,23
Italy,England,9,40
Wales,France,19,10
Italy,Scotland,20,36
England,Ireland,21,10
Ireland,Italy,58,15
England,Wales,25,21
Scotland,France,29,18
Wales,Italy,67,14
Ireland,Scotland,35,25
France,England,21,31""")


# データ読み込み
df = pd.read_csv(data_csv)

teams = df.home_team.unique()
teams = pd.DataFrame(teams, columns=['team'])
teams['i'] = teams.index

df = pd.merge(df, teams, left_on='home_team', right_on='team', how='left')
df = df.rename(columns = {'i': 'i_home'}).drop('team', 1)
df = pd.merge(df, teams, left_on='away_team', right_on='team', how='left')
df = df.rename(columns = {'i': 'i_away'}).drop('team', 1)

observed_home_goals = df.home_score.values
observed_away_goals = df.away_score.values

home_team = df.i_home.values
away_team = df.i_away.values

num_teams = len(df.i_home.drop_duplicates())
num_games = len(home_team)

g = df.groupby('i_away')
att_starting_points = np.log(g.away_score.mean())
g = df.groupby('i_home')
def_starting_points = -np.log(g.away_score.mean())

model = pm.Model()
with pm.Model() as model:
    # global model parameters
    # 無情報事前分布
    home = pm.Normal('home', 0, tau=.0001)
    tau_att = pm.Gamma('tau_att', .1, .1)
    tau_def = pm.Gamma('tau_def', .1, .1)
    intercept = pm.Normal('intercept', 0, tau=.0001)

    # team-specific model parameters
    atts_star   = pm.Normal("atts_star",
                           mu=0,
                           tau=tau_att,
                           shape=num_teams)
    defs_star   = pm.Normal("defs_star",
                           mu=0,
                           tau=tau_def,
                           shape=num_teams)
    # チームごとアタック
    atts = pm.Deterministic('atts', atts_star - tt.mean(atts_star))
    # チームごとディフェンス
    defs = pm.Deterministic('defs', defs_star - tt.mean(defs_star))
    # ホームのθ(得点)
    home_theta  = tt.exp(intercept + home + atts[home_team] + defs[away_team])
    # アウェイのθ(得点)
    away_theta  = tt.exp(intercept + atts[away_team] + defs[home_team])

    # likelihood of observed data
    home_points = pm.Poisson('home_points', mu=home_theta, observed=observed_home_goals)
    away_points = pm.Poisson('away_points', mu=away_theta, observed=observed_away_goals)

with model:
    # 初期値
    start = pm.find_MAP()
    step = pm.NUTS(state=start)
    # サンプリング
    trace = pm.sample(2000, step, init=start)

# グラフ描画
pm.traceplot(trace)
# 最初のほう捨てたかったら
# pm.traceplot(trace[500:])

サンプリング結果は、traceに保持されている。 f:id:Akiniwa:20170204223048p:plain

atts、defsの各パラメータについて。まず、attsについて、これが大きいほど、得点能力が高いといえるだろう。逆に、defsは小さいほうが失点が少ないチームということを表現している。際立つイングランドのディフェンス。(イタリアぁ…)

pm.forestplot(trace, varnames=['atts'], ylabels=['France','Scotland','Ireland','Wales','Italy','England'], main="Team Offense")

f:id:Akiniwa:20170205174108p:plain

pm.forestplot(trace, varnames=['defs'], ylabels=['France','Scotland','Ireland','Wales','Italy','England'], main="Team Deffence")

f:id:Akiniwa:20170205174120p:plain

ちなみに、各パラメータは、varnamesとget_valuesで取得できる。

for name in trace.varnames:
    values = trace.get_values(name)[500:]
    if name in ('atts', 'defs'):
        for i, team_values in enumerate(values.T):
            print(name, teams.ix[i].values[0], np.median(team_values))
    else:
        print(name, np.median(values))

"""
>>>
home 0.197413525093
tau_att_log_ 2.36765017677
tau_def_log_ 1.62635825121
intercept 2.94638413619
atts_star 0.0243648852216
defs_star -0.0579292662131
tau_att 10.6722849008
tau_def 5.08532149318
atts France -0.316409735727
atts Scotland 0.102634304541
atts Ireland 0.0401948177547
atts Wales 0.188146988173
atts Italy -0.128166189509
atts England 0.11534176357
defs France -0.0267028885955
defs Scotland 0.0636352715494
defs Ireland -0.170924989081
defs Wales -0.139824762146
defs Italy 0.679569506375
defs England -0.391210743404
"""

パラメータの中央値を使った2017シックスネーションズの推定値↓ (注: あくまで2016のデータでモデリングした結果、こういう推定値が出せますというだけです。絶対にこのスコア通りになるというわけでは全然ないです。念のため。)

Scotland 21-21 Ireland
England 25-9 France
Italy 17-45 Wales
Italy 17-39 Ireland
Wales 18-18 England
France 18-20 Scotland
Scotland 22-24 Wales
Ireland 23-11 France
England 51-11 Italy
Wales 23-17 Ireland
Italy 19-27 France
England 27-14 Scotland
Scotland 50-17 Italy
France 14-22 Wales
Ireland 16-18 England