diff --git a/internal/home/certificate.go b/internal/home/certificate.go index bb0902f8d..fc3d5e2e8 100644 --- a/internal/home/certificate.go +++ b/internal/home/certificate.go @@ -6,9 +6,11 @@ import ( "context" "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/x509" "crypto/x509/pkix" "encoding/base64" + "encoding/hex" "encoding/json" "encoding/pem" "fmt" @@ -26,10 +28,13 @@ import ( const homeCertificateRequestTimeout = 30 * time.Second type homeJWTClaims struct { - CertificateID string `json:"certificate_id"` - IP string `json:"ip"` - Port int `json:"port"` - IssuedAt int64 `json:"iat"` + CertificateID string `json:"certificate_id"` + ClusterID string `json:"cluster_id"` + CAFingerprint string `json:"ca_fingerprint"` + EnrollmentSecret string `json:"enrollment_secret"` + IP string `json:"ip"` + Port int `json:"port"` + IssuedAt int64 `json:"iat"` } type certificateRequestResponse struct { @@ -88,6 +93,15 @@ func parseHomeJWTClaims(rawJWT string) (homeJWTClaims, error) { if strings.TrimSpace(claims.CertificateID) == "" { return claims, fmt.Errorf("home jwt certificate_id is required") } + if strings.TrimSpace(claims.ClusterID) == "" { + return claims, fmt.Errorf("home jwt cluster_id is required") + } + if normalizeFingerprint(claims.CAFingerprint) == "" { + return claims, fmt.Errorf("home jwt ca_fingerprint is required") + } + if strings.TrimSpace(claims.EnrollmentSecret) == "" { + return claims, fmt.Errorf("home jwt enrollment_secret is required") + } if strings.TrimSpace(claims.IP) == "" || claims.Port <= 0 { return claims, fmt.Errorf("home jwt target address is invalid") } @@ -120,6 +134,9 @@ func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths if !fileExists(paths.CACert) { return fmt.Errorf("home ca certificate file is missing") } + if errVerify := verifyCACertificateFile(paths.CACert, claims.CAFingerprint); errVerify != nil { + return errVerify + } if errChmod := chmodCertificateFiles(paths); errChmod != nil { return errChmod } @@ -143,6 +160,9 @@ func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths if strings.TrimSpace(response.Certificate) == "" || strings.TrimSpace(response.CA) == "" { return fmt.Errorf("home certificate response is incomplete") } + if errVerify := verifyCACertificatePEM([]byte(response.CA), claims.CAFingerprint); errVerify != nil { + return errVerify + } if errWrite := writeFile0600(paths.ClientCert, []byte(response.Certificate)); errWrite != nil { return errWrite } @@ -152,6 +172,49 @@ func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths return nil } +func verifyCACertificateFile(path string, expectedFingerprint string) error { + raw, errRead := os.ReadFile(path) + if errRead != nil { + return errRead + } + return verifyCACertificatePEM(raw, expectedFingerprint) +} + +func verifyCACertificatePEM(raw []byte, expectedFingerprint string) error { + actual, errFingerprint := certificateFingerprintPEM(raw) + if errFingerprint != nil { + return errFingerprint + } + expected := normalizeFingerprint(expectedFingerprint) + if expected == "" { + return fmt.Errorf("home ca fingerprint is required") + } + if actual != expected { + return fmt.Errorf("home ca fingerprint mismatch") + } + return nil +} + +func certificateFingerprintPEM(raw []byte) (string, error) { + block, _ := pem.Decode(raw) + if block == nil || block.Type != "CERTIFICATE" { + return "", fmt.Errorf("home ca certificate pem is invalid") + } + cert, errParse := x509.ParseCertificate(block.Bytes) + if errParse != nil { + return "", errParse + } + sum := sha256.Sum256(cert.Raw) + return hex.EncodeToString(sum[:]), nil +} + +func normalizeFingerprint(fingerprint string) string { + fingerprint = strings.TrimSpace(strings.ToLower(fingerprint)) + fingerprint = strings.ReplaceAll(fingerprint, ":", "") + fingerprint = strings.ReplaceAll(fingerprint, " ", "") + return fingerprint +} + func loadOrCreateClientKey(path string) (*rsa.PrivateKey, error) { if fileExists(path) { raw, errRead := os.ReadFile(path) @@ -252,7 +315,7 @@ func requestClientCertificate(ctx context.Context, claims homeJWTClaims, csrPEM if deadline, ok := dialCtx.Deadline(); ok { _ = conn.SetDeadline(deadline) } - if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, string(csrPEM))); errWrite != nil { + if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, claims.EnrollmentSecret, string(csrPEM))); errWrite != nil { return response, errWrite } raw, errRead := readRESPBulk(bufio.NewReader(conn))