bd347edc
feat: add MITM CONNECT handling with body scanning
a73x 2026-03-29 16:38
Add WithCA and WithUpstreamTLS options; replace blind tunnel with TLS interception when a CA is configured, scanning decrypted request bodies. Also extend ca.GenerateLeaf to support IP SAN certificates.
diff --git a/ca/ca.go b/ca/ca.go index 90b0b2c..a1cb414 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -10,6 +10,7 @@ import ( "encoding/pem" "errors" "math/big" "net" "os" "path/filepath" "time" @@ -138,7 +139,6 @@ func GenerateLeaf(host string, caCert *x509.Certificate, caKey *ecdsa.PrivateKey Subject: pkix.Name{ CommonName: host, }, DNSNames: []string{host}, NotBefore: time.Now().Add(-time.Minute), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature, @@ -147,6 +147,12 @@ func GenerateLeaf(host string, caCert *x509.Certificate, caKey *ecdsa.PrivateKey }, } if ip := net.ParseIP(host); ip != nil { template.IPAddresses = []net.IP{ip} } else { template.DNSNames = []string{host} } certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey) if err != nil { return tls.Certificate{}, err diff --git a/proxy/proxy.go b/proxy/proxy.go index 616636e..109d2a7 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -3,6 +3,9 @@ package proxy import ( "bufio" "bytes" "crypto/ecdsa" "crypto/tls" "crypto/x509" "fmt" "io" "log" @@ -10,7 +13,9 @@ import ( "net/http" "os" "strings" "sync" "github.com/xanderle/nono/ca" "github.com/xanderle/nono/scanner" ) @@ -29,10 +34,31 @@ func WithRules(rulesPath string) Option { } } // WithCA returns an Option that enables MITM interception using the given CA cert and key. func WithCA(cert *x509.Certificate, key *ecdsa.PrivateKey) Option { return func(p *Proxy) { p.caCert = cert p.caKey = key p.certCache = make(map[string]*tls.Certificate) } } // WithUpstreamTLS returns an Option that sets the TLS config used when dialing upstream. func WithUpstreamTLS(cfg *tls.Config) Option { return func(p *Proxy) { p.upstreamTLS = cfg } } // Proxy is an HTTP proxy that only allows connections to approved hosts. type Proxy struct { hostsFile string scanner *scanner.Scanner hostsFile string scanner *scanner.Scanner caCert *x509.Certificate caKey *ecdsa.PrivateKey certCache map[string]*tls.Certificate certMu sync.Mutex upstreamTLS *tls.Config } // New creates a new Proxy that checks hosts against the given allowlist file. @@ -80,6 +106,14 @@ func (p *Proxy) isApproved(host string) bool { } func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { if p.caCert == nil { p.handleConnectTunnel(w, r) return } p.handleConnectMITM(w, r) } func (p *Proxy) handleConnectTunnel(w http.ResponseWriter, r *http.Request) { targetConn, err := net.Dial("tcp", r.Host) if err != nil { http.Error(w, err.Error(), http.StatusBadGateway) @@ -107,6 +141,119 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) { targetConn.Close() } func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) { hj, ok := w.(http.Hijacker) if !ok { http.Error(w, "hijacking not supported", http.StatusInternalServerError) return } clientConn, _, err := hj.Hijack() if err != nil { return } clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) host := extractHost(r.Host) leafCert, err := p.getOrCreateLeaf(host) if err != nil { log.Printf("ERROR: failed to get leaf cert for %s: %v", host, err) clientConn.Close() return } tlsConn := tls.Server(clientConn, &tls.Config{ Certificates: []tls.Certificate{*leafCert}, }) if err := tlsConn.Handshake(); err != nil { log.Printf("ERROR: TLS handshake with client failed for %s: %v", host, err) tlsConn.Close() return } defer tlsConn.Close() req, err := http.ReadRequest(bufio.NewReader(tlsConn)) if err != nil { log.Printf("ERROR: failed to read request from TLS conn for %s: %v", host, err) return } findings := p.scanRequest(req) if len(findings) > 0 { rules := make([]string, 0, len(findings)) for _, f := range findings { rules = append(rules, f.Rule) } log.Printf("BLOCKED %s %s [%s]", req.Method, r.Host, strings.Join(rules, ", ")) resp := &http.Response{ StatusCode: http.StatusForbidden, ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(fmt.Sprintf("request blocked: contains sensitive data (%s)\n", strings.Join(rules, ", ")))), } resp.Header.Set("Content-Type", "text/plain") resp.Write(tlsConn) return } upstreamCfg := p.upstreamTLS if upstreamCfg == nil { upstreamCfg = &tls.Config{} } cfg := upstreamCfg.Clone() cfg.ServerName = host upstreamConn, err := tls.Dial("tcp", r.Host, cfg) if err != nil { log.Printf("ERROR: failed to dial upstream %s: %v", r.Host, err) resp := &http.Response{ StatusCode: http.StatusBadGateway, ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(err.Error()+"\n")), } resp.Write(tlsConn) return } defer upstreamConn.Close() req.RequestURI = "" if err := req.Write(upstreamConn); err != nil { log.Printf("ERROR: failed to forward request to %s: %v", r.Host, err) return } upstreamResp, err := http.ReadResponse(bufio.NewReader(upstreamConn), req) if err != nil { log.Printf("ERROR: failed to read response from %s: %v", r.Host, err) return } defer upstreamResp.Body.Close() upstreamResp.Write(tlsConn) } func (p *Proxy) getOrCreateLeaf(host string) (*tls.Certificate, error) { p.certMu.Lock() defer p.certMu.Unlock() if cert, ok := p.certCache[host]; ok { return cert, nil } leaf, err := ca.GenerateLeaf(host, p.caCert, p.caKey) if err != nil { return nil, err } p.certCache[host] = &leaf return &leaf, nil } func (p *Proxy) scanRequest(r *http.Request) []scanner.Finding { if p.scanner == nil { return nil diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index cff0dd7..2746587 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,6 +1,8 @@ package proxy_test import ( "crypto/tls" "crypto/x509" "net/http" "net/http/httptest" "net/url" @@ -9,6 +11,7 @@ import ( "strings" "testing" nca "github.com/xanderle/nono/ca" "github.com/xanderle/nono/proxy" ) @@ -181,3 +184,107 @@ func TestShouldAllowCleanRequest(t *testing.T) { t.Errorf("expected 200, got %d", resp.StatusCode) } } func TestShouldBlockHTTPSRequestWithAWSKey(t *testing.T) { backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) defer backend.Close() backendURL, _ := url.Parse(backend.URL) host := backendURL.Hostname() hostsFile := filepath.Join(t.TempDir(), "approved_hosts") os.WriteFile(hostsFile, []byte(host+"\n"), 0644) caDir := t.TempDir() caCert, caKey, err := nca.LoadOrCreate(caDir) if err != nil { t.Fatalf("CA setup: %v", err) } rulesPath := writeTestRules(t) p := proxy.New(hostsFile, proxy.WithRules(rulesPath), proxy.WithCA(caCert, caKey), proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}), ) srv := httptest.NewServer(p) defer srv.Close() caPool := x509.NewCertPool() caPool.AddCert(caCert) proxyURL, _ := url.Parse(srv.URL) client := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), TLSClientConfig: &tls.Config{ RootCAs: caPool, }, }, } body := strings.NewReader("key=AKIAIOSFODNN7EXAMPLE") resp, err := client.Post(backend.URL+"/upload", "text/plain", body) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusForbidden { t.Errorf("expected 403, got %d", resp.StatusCode) } } func TestShouldAllowCleanHTTPSRequest(t *testing.T) { backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) })) defer backend.Close() backendURL, _ := url.Parse(backend.URL) host := backendURL.Hostname() hostsFile := filepath.Join(t.TempDir(), "approved_hosts") os.WriteFile(hostsFile, []byte(host+"\n"), 0644) caDir := t.TempDir() caCert, caKey, err := nca.LoadOrCreate(caDir) if err != nil { t.Fatalf("CA setup: %v", err) } rulesPath := writeTestRules(t) p := proxy.New(hostsFile, proxy.WithRules(rulesPath), proxy.WithCA(caCert, caKey), proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}), ) srv := httptest.NewServer(p) defer srv.Close() caPool := x509.NewCertPool() caPool.AddCert(caCert) proxyURL, _ := url.Parse(srv.URL) client := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), TLSClientConfig: &tls.Config{ RootCAs: caPool, }, }, } resp, err := client.Post(backend.URL+"/data", "text/plain", strings.NewReader("clean")) if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected 200, got %d", resp.StatusCode) } }