ぴょこりんブログ

裏垢です。

ゆるふわに傾向スコアマッチングを理解してみる。

この記事は異世界行ったら本気だすぴょこりんクラスタ Advent Calendar 2021のためにかきました。

はじめに

最近流行りの因果推論ってやつをちょっとかじってみようかなと思った、という感じのアレです。ちょろっと本読んだけど少しやっぱ自分の手を動かして色々確認してみるのも意義があるかなぁということで傾向スコアマッチング、やってみます。

問題設定

効果測定の問題を取り扱います。例えば実数をなす2変数からなるサンプル(x1,x2)があったとします。。加えて、各々のサンプルに対して何らかの介入がなされるかどうか意味する介入変数z∈{0,1}が与えられているものとします。サンプル固有の変数と、そこに加わる介入の有無により、効果が計算できるとします。例えば、

y=100*x1+60*x2+20*z

とします。ここでzにかかる定数、介入による真の効果を20としましょう。

yの値と(x1,x2)、そしてzが既知、そして上述の数式を事前知識としてもたないときに、介入による真の効果を推定する問題を考えます。

このような問題設定の例えとして、x1,x2が顧客のプロファイル情報、zを営業施策の有無、yが支払額としたとき、営業施策の効果はどうだったのか?といった効果検証の問題になります。

素朴な解法

一番簡単なやり方として、z=0のときの平均とz=1のときの差分がその真の効果になるだろう、というのが一つのアイディアです。

z=0のデータが取れているサンプルと、z=1のデータが取れているサンプルが同じ分布から生成されているのであれば、これは良い近似になります。

以下のようにして正規分布からデータを生成してみましょう。

n_sample_a=1000
m_1=np.random.normal(0,1,2)
m_2=np.random.normal(0,1,2)
v_1=invwishart(2,np.array([0.001,0.001])).rvs()
v_2=invwishart(2,np.array([0.001,0.001])).rvs()
sample_a=multivariate_normal(m_1,v_1).rvs(n_sample_a)
sample_b=multivariate_normal(m_1,v_1).rvs(n_sample_a)

こんな感じのデータが得られます。一応z=0 (a群)のサンプルとz=0 (b群)のサンプルで色を分けています。f:id:cappsLk:20211203213143p:plain

これらに先ほどの式を適用して、平均の差を算出してみます。

y_a=100*sample_a[:,0]+60*sample_a[:,1] # a群
y_b=100*sample_b[:,0]+60*sample_b[:,1]+20 # b群

zは0か1の値を取るので、実装上はただ20を足すだけになっています。

この引き算をすれば、当然ほぼ20になるはずで、

print(np.mean(y_b)-np.mean(y_a)) 
# 20.083342615002607

ほぼ20です。ここまでは良いでしょう。

偏りのあるサンプル

先ほどの例は、Zの与え方とXの与え方が独立でした。グラフィカルモデルを書くと以下のような感じ。

じゃあそうじゃないケース、つまり、ZがXの与え方に影響を与えてしまう、以下のようなケース。

特定層、例えば高額課金顧客に向けてのみ施策を実施する、みたいなことをすると、こういうことが起きる。

このような場合に先ほどのような単純な方法が効かなくなります。(x1,x2)の振る舞いが異なるので、真の介入の効果以外の項の影響で平均値が異なってしまう。

実際に計算してみましょう。

n_sample_a=1000
n_sample_b=1000
m_1=np.random.normal(0,0.1,2)
m_2=np.random.normal(0,0.1,2)
v_1=invwishart(2,np.array([0.001,0.001])).rvs()
v_2=invwishart(2,np.array([0.001,0.001])).rvs()
sample_a=multivariate_normal(m_1,v_1).rvs(n_sample_a)
sample_b=multivariate_normal(m_2,v_2).rvs(n_sample_b)

f:id:cappsLk:20211203213430p:plain

まぁまぁ嫌がらせっぽい乱数が引けたのではないでしょうか。

y_a=100*sample_a[:,0]+60*sample_a[:,1] # a群
y_b=100*sample_b[:,0]+60*sample_b[:,1]+20 # b群
print(np.mean(y_b)-np.mean(y_a)) 
# 32.17765929250003

32とだいぶ大きめに効果を見積もっています。散布図からわかるようにz=1のサンプルの方がx1, x2ともに大きい値を取りがちなので、その影響が差分に加算されていると考えられます。

傾向スコアを活用したバイアス除去

まず傾向スコアを導入します。これはサンプルに対応するp(z=1|x)、すなわち観測されたxが介入が実施されるかどうかの確率です。

これは真の値がわかってるわけではないので、介入の有無を分類する機械学習手法を用いて近似値を得ます。簡単のためLogistic Regressionを使いましょう。

from sklearn.linear_model import LogisticRegression
z=np.concatenate([np.zeros(n_sample_a),np.ones(n_sample_b)],0)
samples=np.concatenate([sample_a,sample_b],0)
classifier=LogisticRegression()
classifier.fit(samples,z)
propensity_scores=classifier.predict_proba(samples)
score_a=propensity_scores[:n_sample_a,1]
score_b=propensity_scores[n_sample_a:,1]

推定された傾向スコアのヒストグラムは以下のようになります。 f:id:cappsLk:20211203214306p:plain

そういう識別器を作ってるので当たり前なんですが、z=1の時の方が高い数値になります。

この値を使って、バイアスを除去してみましょう。

傾向スコアマッチング

z=1のサンプル各々に対して、傾向スコアが最も近いz=0のサンプルとの差を取り、その平均を効果とする。

expect2=0
for y,score in zip(y_b,score_b):
    expect2+=(y-y_a[np.argmin([np.abs(score-j) for j in score_a])])/len(score_b)
print(expect2)
# 20.952883603594497

悪くない数値が出ました。

何が起こっているのか

f:id:cappsLk:20211203215302p:plain この図は何となく自分の解釈用に書いたのですが、先ほどの散布図のデータ点に対して、Logistic Regressionをしているので、多分図の茶色の線のような識別境界を引きます。そして右上が傾向スコア高、左下が傾向スコアの領域低になるはずです。

で、オレンジのz=1のサンプルに対して、傾向スコアが近い値との差分を取ろうとするので、z=0の中でも傾向スコアが高い境界面付近のサンプル(赤楕円内)のもののみを選択して差分を取ろうとしています。結果としてz=0のサンプルの中でも数値が大きいものを選択して差分を計算し、z=0サンプル特有のxの小ささによるバイアスを低減していると解釈できます。

ここで気が付いたのですが、z=0とz=1が完全に別の分布というよりは、z=0のサンプルの一部がz=1のサンプルになっているみたいな、ある程度重複がある関係の時にこの手法は良い近似が得られる気がしてきました。現実的にも多分そういうケースが多いと思うので、まぁそういうケースが多いからよく機能するのでしょう。

極端な例を挙げると、例えばこんな散布図、 f:id:cappsLk:20211203220855p:plain

傾向スコアのヒストグラムは、 f:id:cappsLk:20211203220928p:plain

潔い位完全に分離できてますね。傾向スコアが近いサンプルが見つからないなんてことになってしまいます。

最初のナイーブな差分が-46.57、傾向スコアマッチングの結果が-44.27。わずかに気持ち程度真の値の20に近づいていますね。多分z=0の左上のサンプルとばかりマッチングしてバイアスを減らそうとするのでしょうけど、まぁそもそもこんなに離れてちゃバイアス取り除けなさそうですね。

分布に重複があり、なおかつz=0側にサンプルが潤沢にある場合、良いマッチング相手が見つかって、良い近似が得られる手法なのかなと理解しました。

あとは純粋にマッチングが計算量的に高コストなので、サンプル数が多いとつらいってデメリットもありそうですね。

終わりに

というわけでダミーデータに対して傾向スコアマッチングを適用して、真の介入による効果を推定するということをしてみました。手法自体は簡単なアイディアで、実装も容易ですがなかなかのものです。

しかしながら、当然万能なんてことは無く、うまく効く制約というのもありそうというのが何となくわかってきました。こういう理解が深められるのがダミーデータで遊ぶ魅力かと思います。