diff --git a/mysql-haskell.cabal b/mysql-haskell.cabal index 237d457..ea41c9f 100644 --- a/mysql-haskell.cabal +++ b/mysql-haskell.cabal @@ -93,7 +93,8 @@ library time >=1.5.0 && <1.12 || ^>=1.12.2 || ^>=1.14, tls >=1.7.0 && <1.8 || ^>=1.8.0 || ^>=1.9.0 || ^>=2.0.0 || ^>=2.1.0, vector >=0.8 && <0.13 || ^>=0.13.0, - word-compat >=0.0 && <0.1 + word-compat >=0.0 && <0.1, + cryptostore default-language: Haskell2010 default-extensions: diff --git a/src/Database/MySQL/Connection.hs b/src/Database/MySQL/Connection.hs index a0f6daf..c502ad5 100644 --- a/src/Database/MySQL/Connection.hs +++ b/src/Database/MySQL/Connection.hs @@ -15,8 +15,12 @@ module Database.MySQL.Connection where import Control.Exception (Exception, bracketOnError, throwIO, catch, SomeException) +import Crypto.Hash.Algorithms (SHA1 (..)) import Control.Monad import qualified Crypto.Hash as Crypto +import qualified Crypto.Store.X509 as X509 +import qualified Data.X509 as X509 +import qualified Crypto.PubKey.RSA.OAEP as RSA import qualified Data.Binary as Binary import qualified Data.Binary.Put as Binary import Data.Bits @@ -116,6 +120,28 @@ connectDetail (ConnectInfo host port db user pass charset) greet <- decodeFromPacket p let auth = mkAuth db user pass charset greet write c $ encodeToPacket 1 auth + p2 <- readPacket is' + if pBody p2 == "\x01\x04" -- Full authentication + then do + -- TODO: unix socketやTLSを使っている場合は別の処理が必要 + write c $ encodeToPacket 3 RequestPubKey + p3 <- readPacket is' + let textKey = L.toStrict $ L.drop 1 (pBody p3) + let keys = X509.readPubKeyFileFromMemory textKey + let pubkey = case keys of + [X509.PubKeyRSA key] -> key + _ -> error "Invalid RSA key" + let nonce = greetingSalt1 greet `B.append` greetingSalt2 greet + let xorPass = B.pack $ zipWith xor (B.unpack (B.append pass "\0")) (cycle (B.unpack nonce)) + eEncrypted <- RSA.encrypt (RSA.defaultOAEPParams SHA1) pubkey xorPass + let encrypted = case eEncrypted of + Left err -> error $ show err + Right enc -> enc + write c $ encodeToPacket 5 $ SendEncryptedPassword encrypted + pure () + else if pBody p2 == "\01\x03" -- Fast authentication OK + then pure () + else throwIO (UnexpectedPacket p2) q <- readPacket is' if isOK q then do @@ -136,18 +162,18 @@ connectDetail (ConnectInfo host port db user pass charset) mkAuth :: ByteString -> ByteString -> ByteString -> Word8 -> Greeting -> Auth mkAuth db user pass charset greet = let salt = greetingSalt1 greet `B.append` greetingSalt2 greet - scambleBuf = scramble salt pass + scambleBuf = scrambleSha256 salt pass in Auth clientCap clientMaxPacketSize charset user scambleBuf db - where - scramble :: ByteString -> ByteString -> ByteString - scramble salt pass' - | B.null pass' = B.empty - | otherwise = B.pack (B.zipWith xor sha1pass withSalt) - where sha1pass = sha1 pass' - withSalt = sha1 (salt `B.append` sha1 sha1pass) - - sha1 :: ByteString -> ByteString - sha1 = BA.convert . (Crypto.hash :: ByteString -> Crypto.Digest Crypto.SHA1) + +scrambleSha256 :: ByteString -> ByteString -> ByteString +scrambleSha256 salt pass + | B.null pass = B.empty + | otherwise = B.pack (B.zipWith xor sha256pass withSalt256) + where sha256pass = sha256 pass + withSalt256 = sha256_2 (sha256 sha256pass) salt + sha256 = BA.convert . (Crypto.hash :: ByteString -> Crypto.Digest Crypto.SHA256) + sha256_2 bytes1 bytes2 = BA.convert $ Crypto.hashFinalize (Crypto.hashUpdate (Crypto.hashUpdate (Crypto.hashInit :: Crypto.Context Crypto.SHA256) bytes1) bytes2) + -- | A specialized 'decodeInputStream' here for speed decodeInputStream :: InputStream ByteString -> IO (InputStream Packet) diff --git a/src/Database/MySQL/Protocol/Auth.hs b/src/Database/MySQL/Protocol/Auth.hs index 58b51c8..eed3fae 100644 --- a/src/Database/MySQL/Protocol/Auth.hs +++ b/src/Database/MySQL/Protocol/Auth.hs @@ -144,6 +144,26 @@ putAuth (Auth cap m c n p s) = do putByteString p putByteString s putWord8 0x00 + putByteString "caching_sha2_password" + putWord8 0x00 + +data RequestPubKey = RequestPubKey deriving (Show, Eq) + +putRequestPubKey :: RequestPubKey -> Put +putRequestPubKey _ = putWord8 0x02 + +instance Binary RequestPubKey where + put = putRequestPubKey + get = pure RequestPubKey + +data SendEncryptedPassword = SendEncryptedPassword !ByteString deriving (Show, Eq) + +putSendEncryptedPassword :: SendEncryptedPassword -> Put +putSendEncryptedPassword (SendEncryptedPassword p) = putByteString p + +instance Binary SendEncryptedPassword where + put = putSendEncryptedPassword + get = SendEncryptedPassword <$> getByteStringNul instance Binary Auth where get = getAuth @@ -182,6 +202,7 @@ clientCap = CLIENT_LONG_PASSWORD .|. CLIENT_MULTI_STATEMENTS .|. CLIENT_MULTI_RESULTS .|. CLIENT_SECURE_CONNECTION + .|. CLIENT_PLUGIN_AUTH clientMaxPacketSize :: Word32 clientMaxPacketSize = 0x00ffffff :: Word32