トップ «前の日記(2013-04-09) 最新 次の日記(2013-11-28)» 月表示 編集

日々の流転


2013-11-26 [長年日記]

λ. 自動微分を使って線形回帰をしてみる

Conal Elliott の Beautiful differentiation の論文を読んだ話 の続き。 自動微分の機械学習系での応用は分かるという話で、せっかくなので一番簡単な線形回帰でも使って試してみることに。

データ例

例としては、Rでお馴染みのgaltonのデータを使う。 これは、リンク先にも書いてあるけど、Galtonさんが1885に両親の身長と子供の身長の関係を分析したデータセットで、928サンプルが含まれている。RのUsingRパッケージに含まれているので、これをRで以下のようにして、galton.csvに保存しておく。

> install.packages("UsingR")
> library(UsingR)
> data(galton)
> write.csv(galton, "galton.csv", row.names = FALSE)

せっかくなので、Rで線形回帰をして散布図上にプロットしておく。

> lm <- lm(galton$child ~ galton$parent)
> summary(lm)

Call:
lm(formula = galton$child ~ galton$parent)

Residuals:
    Min      1Q  Median      3Q     Max 
-7.8050 -1.3661  0.0487  1.6339  5.9264 

Coefficients:
              Estimate Std. Error t value Pr(>|t|)    
(Intercept)   23.94153    2.81088   8.517   <2e-16 ***
galton$parent  0.64629    0.04114  15.711   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 

Residual standard error: 2.239 on 926 degrees of freedom
Multiple R-squared: 0.2105,	Adjusted R-squared: 0.2096 
F-statistic: 246.8 on 1 and 926 DF,  p-value: < 2.2e-16

> plot(galton$parent, galton$child, pch=19, col="blue")
> lines(galton$parent, lm$fitted, col="red", lwd=3)

自動微分のコード

Conal Elliott のコードを使おうかと思ったけれど、Hackageにekmettプロダクトのadパッケージがあったので、それを使うことに。バージョンは3.4を使う。 そういえば、先日開催されていたekmett勉強会でも、@nebutalabという方がadの紹介をしていた。

最急降下法

で、grad関数で自動微分して勾配を求めて、それを使って最急降下法するコードを自分で書こうと思ったら、ちょうど都合の良いことにgradientDescentという関数が用意されていた。 なので、これを使うことにして書いてみたのが以下のコード。 gradientDescentには普通に目的関数costと初期値を引き渡しているだけで、勾配を計算するための関数などを全く引き渡していないことに注目。

-- Galton.hs
module Main where

import Control.Monad
import Numeric.AD
import Text.Printf
import qualified Text.CSV as CSV

main :: IO ()
main = do
  Right csv <- CSV.parseCSVFromFile "galton.csv"
  let samples :: [(Double, Double)]
      samples = [(read parent, read child) | [child,parent] <- tail csv]      
      -- hypothesis
      h [theta0,theta1] x = theta0 + theta1*x
      -- cost function
      cost theta = mse [(realToFrac x, realToFrac y) | (x,y) <- samples] (h theta)
  forM_ (zip [(0::Int)..] (gradientDescent cost [0,0])) $ \(n,theta) -> do
    printf "[%d] cost = %f, theta = %s\n" n (cost theta :: Double) (show theta)

-- mean squared error
mse :: Fractional y => [(x,y)] -> (x -> y) -> y
mse samples h = sum [(h x - y)^(2::Int) | (x,y) <- samples] / fromIntegral (length samples)

実行結果

それで、コンパイルして実行してみると……

[0] cost = 3156.057306315988, theta = [2.6597058526401096e-2,1.8176025390625015]
[1] cost = 2146.158122210389, theta = [4.6848451721258726e-3,0.3193588373256373]
[2] cost = 1459.9671459993929, theta = [2.2758658044906822e-2,1.5543558052476183]
[3] cost = 993.724511817967, theta = [7.872126807824226e-3,0.5363518751316139]
[4] cost = 676.9290401789791, theta = [2.0154694031410444e-2,1.3754888484523264]
[5] cost = 461.6776592383078, theta = [1.0041866689573948e-2,0.6837909525010937]
[6] cost = 315.42191649062323, theta = [1.8389486179011837e-2,1.2539549823963831]
[7] cost = 216.0462832295627, theta = [1.1520222735584755e-2,0.783970560056418]
[8] cost = 148.52403553080242, theta = [1.719418363575862e-2,1.1713769350191798]
[9] cost = 102.64504299700745, theta = [1.252880777988588e-2,0.8520390200495115]
…
[100] cost = 5.391540483678139, theta = [1.5252813280501425e-2,0.9963219150471427]
[101] cost = 5.391540272870362, theta = [1.5259169016306468e-2,0.9963197390947276]
[102] cost = 5.391540062688162, theta = [1.5265580330093717e-2,0.9963213622860531]
[103] cost = 5.3915398529310075, theta = [1.5271945828031475e-2,0.9963198538562846]
[104] cost = 5.391539446704301, theta = [1.5284752349382588e-2,0.996321999765866]
…
[1000] cost = 5.3913214546153885, theta = [2.195100305728909e-2,0.9962239448356897]
[1001] cost = 5.39132107512034, theta = [2.19637098242177e-2,0.9962195158083134]
[1002] cost = 5.391320924102971, theta = [2.1976643022454955e-2,0.996230564935548]
[1003] cost = 5.391320615754077, theta = [2.198280974693417e-2,0.9962155918323486]
[1004] cost = 5.391320339333709, theta = [2.1989373588151114e-2,0.9962277637190161]
…
[10000] cost = 5.389137897180453, theta = [8.88298535971496e-2,0.9952457730583962]
[10001] cost = 5.38913768668423, theta = [8.8836182876215e-2,0.9952431313378857]
[10002] cost = 5.389137477124735, theta = [8.884258017163454e-2,0.995245138984378]
[10003] cost = 5.389137268201876, theta = [8.884892139830691e-2,0.995243314173662]
[10004] cost = 5.389136870256361, theta = [8.886169628000305e-2,0.9952459827149522]

全然収束しない。収束遅すぎだろ。めげずにさらに放置してると……

[100000] cost = 5.367962167905593, theta = [0.7474196350908293,0.9856182795949244]
[100001] cost = 5.367961856713982, theta = [0.7474255920526732,0.9856022065240108]
[100002] cost = 5.367961582358961, theta = [0.7474319755635072,0.9856152902779793]
[100003] cost = 5.367961333033166, theta = [0.747438007468459,0.9856043401613668]
[100004] cost = 5.367961100713857, theta = [0.7474443291979176,0.9856132010817849]
…
[1000000] cost = 5.210315054542271, theta = [6.411550087218479,0.9027479380474824]
[1000001] cost = 5.210314936517216, theta = [6.411554713147153,0.9027442449249086]
[1000002] cost = 5.210314820386824, theta = [6.411559435811994,0.9027471642783936]
[1000003] cost = 5.210314705543812, theta = [6.411564078735267,0.9027446329906037]
[1000004] cost = 5.210314591575562, theta = [6.411568787386962,0.9027465946477329]
…
[5000000] cost = 5.0177281352658305, theta = [18.890815946868038,0.7201815311956332]
[5000001] cost = 5.017728123133143, theta = [18.890817258298373,0.7201790053001258]
[5000002] cost = 5.017728111906337, theta = [18.89081863661461,0.720181051408368]
[5000003] cost = 5.017728101294881, theta = [18.89081995979637,0.7201793288296293]
[5000004] cost = 5.01772809110162, theta = [18.890821328424636,0.7201807127666369]
…
[10000000] cost = 5.001070611475556, theta = [22.875383075131683,0.6618880826205324]
[10000001] cost = 5.00107061095209, theta = [22.87538335250446,0.6618875866914469]
[10000002] cost = 5.001070610463426, theta = [22.875383643001708,0.6618879878893927]
[10000003] cost = 5.001070609998513, theta = [22.875383922680363,0.6618876495886096]
[10000004] cost = 5.001070609549686, theta = [22.87538421127661,0.6618879208540909]

まだかかるのかよ……

[15000000] cost = 5.000328380487849, theta = [23.716478910924902,0.6495830291863548]
[15000001] cost = 5.0003283804652, theta = [23.71647896958202,0.6495829318117019]
[15000002] cost = 5.000328380443848, theta = [23.71647903081446,0.6495830104741416]
[15000003] cost = 5.000328380423461, theta = [23.716479089924043,0.6495829440298111]
[15000004] cost = 5.000328380403642, theta = [23.71647915078346,0.649582997196493]
…
[20000000] cost = 5.0002953079344685, theta = [23.894024471635685,0.6469855785633368]
[20000001] cost = 5.000295307933508, theta = [23.894024484038532,0.64698555944479]
[20000002] cost = 5.000295307932569, theta = [23.89402449694667,0.6469855748657359]
[20000003] cost = 5.0002953079316805, theta = [23.894024509438292,0.6469855618158963]
[20000004] cost = 5.000295307930787, theta = [23.89402452227324,0.6469855722344278]
…
[22987995] cost = 5.000294005873883, theta = [23.922778217862263,0.6465649143875803]
[22987996] cost = 5.000294005873744, theta = [23.922778220364872,0.6465649143543949]
[22987997] cost = 5.000294005873735, theta = [23.922778220366094,0.6465649143543771]
[22987998] cost = 5.000294005873733, theta = [23.922778220367316,0.6465649143543594]
[22987999] cost = 5.000294005873733, theta = [23.922778220367316,0.6465649143543594]
[22988000] cost = 5.000294005873733, theta = [23.922778220367316,0.6465649143543594]

やった! 収束した!

Rだと一瞬だったところ、1.5日くらいかかったし、Rで計算した theta0=23.94153, theta1=0.64629 とはちょっとずれてはいるけれど。

感想

Courseraの Machine Learning のコースで、最急降下法で線形回帰を実装したときは1000ステップくらいで収束してたので、これもそれくらいで収束するかと思ってたんだけれど、こんなにかかるとは以外だった。 gradientDescent関数(ソース)の実装に何か問題があるのか、それともそもそも素朴な最急降下法には荷の重い問題だったのか分からんけれど、2パラメータのこの程度の問題でこれは流石に酷い。

本当はこの後は、ロジスティック回帰とニューラルネットの例も何かやってみようかと思ってたのだけれど、その前にそっちを何とかしないといけないかも。 Machine Learning のコースでも最初は簡単な再急降下法を自前で実装させていたけれど、その後は最適化は予め用意してある実装(fmincg.m)を使うようになってたからなぁ。 まあ、時間かかりすぎて疲れたので、また今度。

【追記】頂いたコメンドなど

【追記】 解決策

本日のツッコミ(全2件) [ツッコミを入れる]
ψ mkotha (2013-11-28 18:52)

入力をおおざっぱに正規化(68を引いてから4で割る)してから最適化したところ、99ステップで(Doubleの最後の桁が動かなくなるまで)収束しました

ψ sakai (2013-11-28 21:51)

おお〜! そういえば、Courseraの講義でも feature scaling が勾配法の収束の速度に効いてくるというようなことを言っていたような気がしますが、初めて実感しました。