a73x

bd347edc

feat: add MITM CONNECT handling with body scanning

a73x   2026-03-29 16:38

Add WithCA and WithUpstreamTLS options; replace blind tunnel with TLS
interception when a CA is configured, scanning decrypted request bodies.
Also extend ca.GenerateLeaf to support IP SAN certificates.

diff --git a/ca/ca.go b/ca/ca.go
index 90b0b2c..a1cb414 100644
--- a/ca/ca.go
+++ b/ca/ca.go
@@ -10,6 +10,7 @@ import (
	"encoding/pem"
	"errors"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"time"
@@ -138,7 +139,6 @@ func GenerateLeaf(host string, caCert *x509.Certificate, caKey *ecdsa.PrivateKey
		Subject: pkix.Name{
			CommonName: host,
		},
		DNSNames:  []string{host},
		NotBefore: time.Now().Add(-time.Minute),
		NotAfter:  time.Now().Add(24 * time.Hour),
		KeyUsage:  x509.KeyUsageDigitalSignature,
@@ -147,6 +147,12 @@ func GenerateLeaf(host string, caCert *x509.Certificate, caKey *ecdsa.PrivateKey
		},
	}

	if ip := net.ParseIP(host); ip != nil {
		template.IPAddresses = []net.IP{ip}
	} else {
		template.DNSNames = []string{host}
	}

	certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey)
	if err != nil {
		return tls.Certificate{}, err
diff --git a/proxy/proxy.go b/proxy/proxy.go
index 616636e..109d2a7 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -3,6 +3,9 @@ package proxy
import (
	"bufio"
	"bytes"
	"crypto/ecdsa"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"log"
@@ -10,7 +13,9 @@ import (
	"net/http"
	"os"
	"strings"
	"sync"

	"github.com/xanderle/nono/ca"
	"github.com/xanderle/nono/scanner"
)

@@ -29,10 +34,31 @@ func WithRules(rulesPath string) Option {
	}
}

// WithCA returns an Option that enables MITM interception using the given CA cert and key.
func WithCA(cert *x509.Certificate, key *ecdsa.PrivateKey) Option {
	return func(p *Proxy) {
		p.caCert = cert
		p.caKey = key
		p.certCache = make(map[string]*tls.Certificate)
	}
}

// WithUpstreamTLS returns an Option that sets the TLS config used when dialing upstream.
func WithUpstreamTLS(cfg *tls.Config) Option {
	return func(p *Proxy) {
		p.upstreamTLS = cfg
	}
}

// Proxy is an HTTP proxy that only allows connections to approved hosts.
type Proxy struct {
	hostsFile string
	scanner   *scanner.Scanner
	hostsFile   string
	scanner     *scanner.Scanner
	caCert      *x509.Certificate
	caKey       *ecdsa.PrivateKey
	certCache   map[string]*tls.Certificate
	certMu      sync.Mutex
	upstreamTLS *tls.Config
}

// New creates a new Proxy that checks hosts against the given allowlist file.
@@ -80,6 +106,14 @@ func (p *Proxy) isApproved(host string) bool {
}

func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	if p.caCert == nil {
		p.handleConnectTunnel(w, r)
		return
	}
	p.handleConnectMITM(w, r)
}

func (p *Proxy) handleConnectTunnel(w http.ResponseWriter, r *http.Request) {
	targetConn, err := net.Dial("tcp", r.Host)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
@@ -107,6 +141,119 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	targetConn.Close()
}

func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) {
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}

	clientConn, _, err := hj.Hijack()
	if err != nil {
		return
	}

	clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))

	host := extractHost(r.Host)

	leafCert, err := p.getOrCreateLeaf(host)
	if err != nil {
		log.Printf("ERROR: failed to get leaf cert for %s: %v", host, err)
		clientConn.Close()
		return
	}

	tlsConn := tls.Server(clientConn, &tls.Config{
		Certificates: []tls.Certificate{*leafCert},
	})
	if err := tlsConn.Handshake(); err != nil {
		log.Printf("ERROR: TLS handshake with client failed for %s: %v", host, err)
		tlsConn.Close()
		return
	}
	defer tlsConn.Close()

	req, err := http.ReadRequest(bufio.NewReader(tlsConn))
	if err != nil {
		log.Printf("ERROR: failed to read request from TLS conn for %s: %v", host, err)
		return
	}

	findings := p.scanRequest(req)
	if len(findings) > 0 {
		rules := make([]string, 0, len(findings))
		for _, f := range findings {
			rules = append(rules, f.Rule)
		}
		log.Printf("BLOCKED %s %s [%s]", req.Method, r.Host, strings.Join(rules, ", "))
		resp := &http.Response{
			StatusCode: http.StatusForbidden,
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(fmt.Sprintf("request blocked: contains sensitive data (%s)\n", strings.Join(rules, ", ")))),
		}
		resp.Header.Set("Content-Type", "text/plain")
		resp.Write(tlsConn)
		return
	}

	upstreamCfg := p.upstreamTLS
	if upstreamCfg == nil {
		upstreamCfg = &tls.Config{}
	}
	cfg := upstreamCfg.Clone()
	cfg.ServerName = host

	upstreamConn, err := tls.Dial("tcp", r.Host, cfg)
	if err != nil {
		log.Printf("ERROR: failed to dial upstream %s: %v", r.Host, err)
		resp := &http.Response{
			StatusCode: http.StatusBadGateway,
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(err.Error()+"\n")),
		}
		resp.Write(tlsConn)
		return
	}
	defer upstreamConn.Close()

	req.RequestURI = ""
	if err := req.Write(upstreamConn); err != nil {
		log.Printf("ERROR: failed to forward request to %s: %v", r.Host, err)
		return
	}

	upstreamResp, err := http.ReadResponse(bufio.NewReader(upstreamConn), req)
	if err != nil {
		log.Printf("ERROR: failed to read response from %s: %v", r.Host, err)
		return
	}
	defer upstreamResp.Body.Close()

	upstreamResp.Write(tlsConn)
}

func (p *Proxy) getOrCreateLeaf(host string) (*tls.Certificate, error) {
	p.certMu.Lock()
	defer p.certMu.Unlock()

	if cert, ok := p.certCache[host]; ok {
		return cert, nil
	}

	leaf, err := ca.GenerateLeaf(host, p.caCert, p.caKey)
	if err != nil {
		return nil, err
	}

	p.certCache[host] = &leaf
	return &leaf, nil
}

func (p *Proxy) scanRequest(r *http.Request) []scanner.Finding {
	if p.scanner == nil {
		return nil
diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go
index cff0dd7..2746587 100644
--- a/proxy/proxy_test.go
+++ b/proxy/proxy_test.go
@@ -1,6 +1,8 @@
package proxy_test

import (
	"crypto/tls"
	"crypto/x509"
	"net/http"
	"net/http/httptest"
	"net/url"
@@ -9,6 +11,7 @@ import (
	"strings"
	"testing"

	nca "github.com/xanderle/nono/ca"
	"github.com/xanderle/nono/proxy"
)

@@ -181,3 +184,107 @@ func TestShouldAllowCleanRequest(t *testing.T) {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}

func TestShouldBlockHTTPSRequestWithAWSKey(t *testing.T) {
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, _ := url.Parse(backend.URL)
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	os.WriteFile(hostsFile, []byte(host+"\n"), 0644)

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile,
		proxy.WithRules(rulesPath),
		proxy.WithCA(caCert, caKey),
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, _ := url.Parse(srv.URL)
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	body := strings.NewReader("key=AKIAIOSFODNN7EXAMPLE")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403, got %d", resp.StatusCode)
	}
}

func TestShouldAllowCleanHTTPSRequest(t *testing.T) {
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("ok"))
	}))
	defer backend.Close()

	backendURL, _ := url.Parse(backend.URL)
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	os.WriteFile(hostsFile, []byte(host+"\n"), 0644)

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile,
		proxy.WithRules(rulesPath),
		proxy.WithCA(caCert, caKey),
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, _ := url.Parse(srv.URL)
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	resp, err := client.Post(backend.URL+"/data", "text/plain", strings.NewReader("clean"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}