diff --git a/msg.go b/msg.go index 9686dac6..286e5331 100644 --- a/msg.go +++ b/msg.go @@ -7,8 +7,7 @@ package mail import ( "bytes" "context" - "crypto/rsa" - "crypto/x509" + "crypto/tls" "embed" "errors" "fmt" @@ -208,8 +207,8 @@ func WithNoDefaultUserAgent() MsgOption { } // SignWithSMime configures the Msg to be signed with S/MIME -func (m *Msg) SignWithSMime(privateKey *rsa.PrivateKey, certificate *x509.Certificate) error { - sMime, err := NewSMime(privateKey, certificate) +func (m *Msg) SignWithSMime(keyPair *tls.Certificate) error { + sMime, err := newSMime(keyPair) if err != nil { return err } @@ -985,28 +984,31 @@ func (m *Msg) applyMiddlewares(msg *Msg) *Msg { // signMessage sign the Msg with S/MIME func (m *Msg) signMessage(msg *Msg) (*Msg, error) { - currentPart := m.GetParts()[0] - currentPart.SetEncoding(EncodingUSASCII) - currentPart.SetContentType(TypeTextPlain) - content, err := currentPart.GetContent() + signedPart := msg.GetParts()[0] + body, err := signedPart.GetContent() if err != nil { - return nil, errors.New("failed to extract content from part") + return nil, err } - signedContent, err := m.sMime.Sign(content) + signaturePart, err := m.createSignaturePart(signedPart.GetEncoding(), signedPart.GetContentType(), signedPart.GetCharset(), body) if err != nil { - return nil, errors.New("failed to sign message") + return nil, err } - signedPart := msg.newPart( - typeSMimeSigned, - WithPartEncoding(EncodingB64), - WithContentDisposition(DispositionSMime), - ) - signedPart.SetContent(*signedContent) - msg.parts = append(msg.parts, signedPart) + m.parts = append(m.parts, signaturePart) - return msg, nil + return m, err +} + +func (m *Msg) createSignaturePart(encoding Encoding, contentType ContentType, charSet Charset, body []byte) (*Part, error) { + message := m.sMime.createMessage(encoding, contentType, charSet, body) + signaturePart := m.newPart(typeSMimeSigned, WithPartEncoding(EncodingB64), WithSMimeSinging()) + + if err := m.sMime.sign(signaturePart, message); err != nil { + return nil, err + } + + return signaturePart, nil } // WriteTo writes the formated Msg into a give io.Writer and satisfies the io.WriteTo interface @@ -1014,7 +1016,7 @@ func (m *Msg) WriteTo(writer io.Writer) (int64, error) { mw := &msgWriter{writer: writer, charset: m.charset, encoder: m.encoder} msg := m.applyMiddlewares(m) - if m.sMime != nil { + if m.hasSMime() { signedMsg, err := m.signMessage(msg) if err != nil { return 0, err @@ -1210,7 +1212,7 @@ func (m *Msg) hasAlt() bool { count++ } } - return count > 1 && m.pgptype == 0 + return count > 1 && m.pgptype == 0 && !m.hasSMime() } // hasMixed returns true if the Msg has mixed parts diff --git a/msgwriter.go b/msgwriter.go index 372e1e1a..7c1e88e5 100644 --- a/msgwriter.go +++ b/msgwriter.go @@ -100,7 +100,7 @@ func (mw *msgWriter) writeMsg(msg *Msg) { mw.startMP(MIMERelated, msg.boundary) mw.writeString(DoubleNewLine) } - if msg.hasAlt() && !msg.hasSMime() { + if msg.hasAlt() { mw.startMP(MIMEAlternative, msg.boundary) mw.writeString(DoubleNewLine) } @@ -241,7 +241,7 @@ func (mw *msgWriter) addFiles(files []*File, isAttachment bool) { } if mw.err == nil { - mw.writeBody(file.Writer, encoding) + mw.writeBody(file.Writer, encoding, false) } } } @@ -273,7 +273,7 @@ func (mw *msgWriter) writePart(part *Part, charset Charset) { mimeHeader.Add(string(HeaderContentTransferEnc), contentTransferEnc) mw.newPart(mimeHeader) } - mw.writeBody(part.writeFunc, part.encoding) + mw.writeBody(part.writeFunc, part.encoding, part.smime) } // writeString writes a string into the msgWriter's io.Writer interface @@ -322,7 +322,7 @@ func (mw *msgWriter) writeHeader(key Header, values ...string) { } // writeBody writes an io.Reader into an io.Writer using provided Encoding -func (mw *msgWriter) writeBody(writeFunc func(io.Writer) (int64, error), encoding Encoding) { +func (mw *msgWriter) writeBody(writeFunc func(io.Writer) (int64, error), encoding Encoding, singingWithSMime bool) { var writer io.Writer var encodedWriter io.WriteCloser var n int64 @@ -337,12 +337,11 @@ func (mw *msgWriter) writeBody(writeFunc func(io.Writer) (int64, error), encodin lineBreaker := Base64LineBreaker{} lineBreaker.out = &writeBuffer - switch encoding { - case EncodingQP: + if encoding == EncodingQP { encodedWriter = quotedprintable.NewWriter(&writeBuffer) - case EncodingB64: + } else if encoding == EncodingB64 && !singingWithSMime { encodedWriter = base64.NewEncoder(base64.StdEncoding, &lineBreaker) - case NoEncoding: + } else if encoding == NoEncoding { _, err = writeFunc(&writeBuffer) if err != nil { mw.err = fmt.Errorf("bodyWriter function: %w", err) @@ -355,7 +354,7 @@ func (mw *msgWriter) writeBody(writeFunc func(io.Writer) (int64, error), encodin mw.bytesWritten += n } return - default: + } else { encodedWriter = quotedprintable.NewWriter(writer) } diff --git a/sime.go b/sime.go index f230a699..fe051589 100644 --- a/sime.go +++ b/sime.go @@ -1,18 +1,20 @@ package mail import ( + "bytes" "crypto/rsa" + "crypto/tls" "crypto/x509" + "encoding/pem" "errors" + "fmt" "go.mozilla.org/pkcs7" + "strings" ) var ( - // ErrInvalidPrivateKey should be used if private key is invalid - ErrInvalidPrivateKey = errors.New("invalid private key") - - // ErrInvalidCertificate should be used if certificate is invalid - ErrInvalidCertificate = errors.New("invalid certificate") + // ErrInvalidKeyPair should be used if key pair is invalid + ErrInvalidKeyPair = errors.New("invalid key pair") // ErrCouldNotInitialize should be used if the signed data could not initialize ErrCouldNotInitialize = errors.New("could not initialize signed data") @@ -22,49 +24,136 @@ var ( // ErrCouldNotFinishSigning should be used if the signing could not be finished ErrCouldNotFinishSigning = errors.New("could not finish signing") + + // ErrCouldNoEncodeToPEM should be used if the signature could not be encoded to PEM + ErrCouldNoEncodeToPEM = errors.New("could not encode to PEM") ) // SMime is used to sign messages with S/MIME type SMime struct { - privateKey *rsa.PrivateKey - certificate *x509.Certificate + privateKey *rsa.PrivateKey + certificate *x509.Certificate + parentCertificates []*x509.Certificate } -// NewSMime construct a new instance of SMime with a provided *rsa.PrivateKey -func NewSMime(privateKey *rsa.PrivateKey, certificate *x509.Certificate) (*SMime, error) { - if privateKey == nil { - return nil, ErrInvalidPrivateKey +// NewSMime construct a new instance of SMime with a provided *tls.Certificate +func newSMime(keyPair *tls.Certificate) (*SMime, error) { + if keyPair == nil { + return nil, ErrInvalidKeyPair } - if certificate == nil { - return nil, ErrInvalidCertificate + parentCertificates := make([]*x509.Certificate, 0) + for _, cert := range keyPair.Certificate[1:] { + c, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err + } + parentCertificates = append(parentCertificates, c) } return &SMime{ - privateKey: privateKey, - certificate: certificate, + privateKey: keyPair.PrivateKey.(*rsa.PrivateKey), + certificate: keyPair.Leaf, + parentCertificates: parentCertificates, }, nil } -// Sign the content with the given privateKey of the method NewSMime -func (sm *SMime) Sign(content []byte) (*string, error) { - toBeSigned, err := pkcs7.NewSignedData(content) +// sign with the S/MIME method the message of the actual *Part +func (sm *SMime) sign(signaturePart *Part, message string) error { + lines := parseLines([]byte(message)) + toBeSigned := lines.bytesFromLines([]byte("\r\n")) - toBeSigned.SetDigestAlgorithm(pkcs7.OIDDigestAlgorithmSHA256) + tmp, err := pkcs7.NewSignedData(toBeSigned) + tmp.SetDigestAlgorithm(pkcs7.OIDDigestAlgorithmSHA256) if err != nil { - return nil, ErrCouldNotInitialize + return ErrCouldNotInitialize + } + + if err = tmp.AddSignerChain(sm.certificate, sm.privateKey, sm.parentCertificates, pkcs7.SignerInfoConfig{}); err != nil { + return ErrCouldNotAddSigner } - if err = toBeSigned.AddSigner(sm.certificate, sm.privateKey, pkcs7.SignerInfoConfig{}); err != nil { - return nil, ErrCouldNotAddSigner + signatureDER, err := tmp.Finish() + if err != nil { + return ErrCouldNotFinishSigning } - signed, err := toBeSigned.Finish() + pemMsg, err := encodeToPEM(signatureDER) if err != nil { - return nil, ErrCouldNotFinishSigning + return ErrCouldNoEncodeToPEM + } + signaturePart.SetContent(*pemMsg) + + return nil +} + +// createMessage prepares the message that will be used for the sign method later +func (sm *SMime) createMessage(encoding Encoding, contentType ContentType, charset Charset, body []byte) string { + return fmt.Sprintf("Content-Transfer-Encoding: %v\r\nContent-Type: %v; charset=%v\r\n\r\n%v", encoding, contentType, charset, string(body)) +} + +// encodeToPEM uses the method pem.Encode from the standard library but cuts the typical PEM preamble +func encodeToPEM(msg []byte) (*string, error) { + block := &pem.Block{Bytes: msg} + + var arrayBuffer bytes.Buffer + if err := pem.Encode(&arrayBuffer, block); err != nil { + return nil, err } - signedData := string(signed) + r := arrayBuffer.String() + r = strings.ReplaceAll(r, "-----BEGIN -----\n", "") + r = strings.ReplaceAll(r, "-----END -----\n", "") + + return &r, nil +} + +// line is the representation of one line of the message that will be used for signing purposes +type line struct { + line []byte + endOfLine []byte +} - return &signedData, nil +// lines is the representation of a message that will be used for signing purposes +type lines []line + +// bytesFromLines creates the line representation with the given endOfLine char +func (ls lines) bytesFromLines(sep []byte) []byte { + var raw []byte + for i := range ls { + raw = append(raw, ls[i].line...) + if len(ls[i].endOfLine) != 0 && sep != nil { + raw = append(raw, sep...) + } else { + raw = append(raw, ls[i].endOfLine...) + } + } + return raw +} + +// parseLines constructs the lines representation of a given message +func parseLines(raw []byte) lines { + oneLine := line{raw, nil} + lines := lines{oneLine} + lines = lines.splitLine([]byte("\r\n")) + lines = lines.splitLine([]byte("\r")) + lines = lines.splitLine([]byte("\n")) + return lines +} + +// splitLine uses the given endOfLine to split the given line +func (ls lines) splitLine(sep []byte) lines { + nl := lines{} + for _, l := range ls { + split := bytes.Split(l.line, sep) + if len(split) > 1 { + for i := 0; i < len(split)-1; i++ { + nl = append(nl, line{split[i], sep}) + } + nl = append(nl, line{split[len(split)-1], l.endOfLine}) + } else { + nl = append(nl, l) + } + } + return nl }