2022-10-11 17:36:09 +08:00
package tls
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"math/big"
"net"
"os"
"time"
)
var isDev bool
// CreateServerTLSConfig creates server tls config.
func CreateServerTLSConfig ( host string ) ( * tls . Config , error ) {
// development mode
if isDev {
tc , err := developmentTLSConfig ( host )
if err != nil {
return nil , err
}
return tc , nil
}
// production mode
// ca pool
pool , err := getCACertPool ( )
if err != nil {
return nil , err
}
// server certificate
tlsCert , err := getCertAndKey ( )
if err != nil {
return nil , err
}
return & tls . Config {
Certificates : [ ] tls . Certificate { * tlsCert } ,
ClientCAs : pool ,
ClientAuth : tls . RequireAndVerifyClientCert ,
NextProtos : [ ] string { "hpds" } ,
} , nil
}
// CreateClientTLSConfig creates client tls config.
func CreateClientTLSConfig ( ) ( * tls . Config , error ) {
// development mode
if isDev {
return & tls . Config {
InsecureSkipVerify : true ,
NextProtos : [ ] string { "hpds" } ,
ClientSessionCache : tls . NewLRUClientSessionCache ( 64 ) ,
} , nil
}
// production mode
pool , err := getCACertPool ( )
if err != nil {
return nil , err
}
tlsCert , err := getCertAndKey ( )
if err != nil {
return nil , err
}
return & tls . Config {
InsecureSkipVerify : false ,
Certificates : [ ] tls . Certificate { * tlsCert } ,
RootCAs : pool ,
NextProtos : [ ] string { "hpds" } ,
ClientSessionCache : tls . NewLRUClientSessionCache ( 0 ) ,
} , nil
}
func getCACertPool ( ) ( * x509 . CertPool , error ) {
var err error
var caCert [ ] byte
caCertPath := os . Getenv ( "HPDS_TLS_CACERT_FILE" )
if len ( caCertPath ) == 0 {
return nil , errors . New ( "tls: must provide CA certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CACERT_FILE`" )
}
2023-03-10 23:49:52 +08:00
caCert , err = os . ReadFile ( caCertPath )
2022-10-11 17:36:09 +08:00
if err != nil {
return nil , err
}
if len ( caCert ) == 0 {
return nil , errors . New ( "tls: cannot load CA cert" )
}
pool := x509 . NewCertPool ( )
if ok := pool . AppendCertsFromPEM ( caCert ) ; ! ok {
return nil , errors . New ( "tls: cannot append CA cert to pool" )
}
return pool , nil
}
func getCertAndKey ( ) ( * tls . Certificate , error ) {
var err error
var cert , key [ ] byte
certPath := os . Getenv ( "HPDS_TLS_CERT_FILE" )
keyPath := os . Getenv ( "HPDS_TLS_KEY_FILE" )
if len ( certPath ) == 0 || len ( keyPath ) == 0 {
return nil , errors . New ( "tls: must provide certificate on production mode, you can configure this via environment variables: `HPDS_TLS_CERT_FILE` and `HPDS_TLS_KEY_FILE`" )
}
// certificate
2023-03-10 23:49:52 +08:00
cert , err = os . ReadFile ( certPath )
2022-10-11 17:36:09 +08:00
if err != nil {
return nil , err
}
// private key
2023-03-10 23:49:52 +08:00
key , err = os . ReadFile ( keyPath )
2022-10-11 17:36:09 +08:00
if err != nil {
return nil , err
}
if len ( cert ) == 0 || len ( key ) == 0 {
return nil , errors . New ( "tls: cannot load tls cert/key" )
}
tlsCert , err := tls . X509KeyPair ( cert , key )
if err != nil {
return nil , err
}
return & tlsCert , nil
}
// IsDev development mode
func IsDev ( ) bool {
return isDev
}
// developmentTLSConfig Setup a bare-bones TLS config for the server
func developmentTLSConfig ( host ... string ) ( * tls . Config , error ) {
tlsCert , err := generateCertificate ( host ... )
if err != nil {
return nil , err
}
return & tls . Config {
Certificates : [ ] tls . Certificate { tlsCert } ,
ClientSessionCache : tls . NewLRUClientSessionCache ( 1 ) ,
NextProtos : [ ] string { "hpds" } ,
} , nil
}
func generateCertificate ( host ... string ) ( tls . Certificate , error ) {
priv , err := ecdsa . GenerateKey ( elliptic . P256 ( ) , rand . Reader )
if err != nil {
return tls . Certificate { } , err
}
notBefore := time . Now ( )
notAfter := notBefore . Add ( time . Hour * 24 * 365 )
serialNumberLimit := new ( big . Int ) . Lsh ( big . NewInt ( 1 ) , 128 )
serialNumber , err := rand . Int ( rand . Reader , serialNumberLimit )
if err != nil {
return tls . Certificate { } , err
}
template := x509 . Certificate {
SerialNumber : serialNumber ,
Subject : pkix . Name {
Organization : [ ] string { "HPDS" } ,
} ,
NotBefore : notBefore ,
NotAfter : notAfter ,
KeyUsage : x509 . KeyUsageKeyEncipherment | x509 . KeyUsageDigitalSignature ,
ExtKeyUsage : [ ] x509 . ExtKeyUsage { x509 . ExtKeyUsageServerAuth } ,
BasicConstraintsValid : true ,
DNSNames : [ ] string { "localhost" } ,
}
for _ , h := range host {
if ip := net . ParseIP ( h ) ; ip != nil {
template . IPAddresses = append ( template . IPAddresses , ip )
} else {
template . DNSNames = append ( template . DNSNames , h )
}
}
template . IsCA = true
template . KeyUsage |= x509 . KeyUsageCertSign
derBytes , err := x509 . CreateCertificate ( rand . Reader , & template , & template , & priv . PublicKey , priv )
if err != nil {
return tls . Certificate { } , err
}
// create public key
certOut := bytes . NewBuffer ( nil )
err = pem . Encode ( certOut , & pem . Block { Type : "CERTIFICATE" , Bytes : derBytes } )
if err != nil {
return tls . Certificate { } , err
}
// create private key
keyOut := bytes . NewBuffer ( nil )
b , err := x509 . MarshalECPrivateKey ( priv )
if err != nil {
return tls . Certificate { } , err
}
err = pem . Encode ( keyOut , & pem . Block { Type : "EC PRIVATE KEY" , Bytes : b } )
if err != nil {
return tls . Certificate { } , err
}
return tls . X509KeyPair ( certOut . Bytes ( ) , keyOut . Bytes ( ) )
}
func init ( ) {
env := os . Getenv ( "HPDS_ENV" )
isDev = len ( env ) == 0 || env != "production"
}