a73x

70c1cc4e

feat: wire middleware into proxy MITM loop

a73x   2026-03-31 05:49

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

diff --git a/proxy/proxy.go b/proxy/proxy.go
index 1a6c468..f336224 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -16,6 +16,7 @@ import (
	"sync"

	"github.com/xanderle/nono/ca"
	"github.com/xanderle/nono/middleware"
	"github.com/xanderle/nono/scanner"
)

@@ -34,6 +35,18 @@ func WithRules(rulesPath string) Option {
	}
}

// WithMiddleware returns an Option that loads middleware from the given config path.
func WithMiddleware(path string) Option {
	return func(p *Proxy) {
		mw, err := middleware.New(path)
		if err != nil {
			log.Printf("WARNING: failed to load middleware from %q: %v", path, err)
			return
		}
		p.middleware = mw
	}
}

// WithCA returns an Option that enables MITM interception using the given CA cert and key.
func WithCA(cert *x509.Certificate, key *ecdsa.PrivateKey) Option {
	return func(p *Proxy) {
@@ -54,6 +67,7 @@ func WithUpstreamTLS(cfg *tls.Config) Option {
type Proxy struct {
	hostsFile   string
	scanner     *scanner.Scanner
	middleware  *middleware.Middleware
	caCert      *x509.Certificate
	caKey       *ecdsa.PrivateKey
	certCache   map[string]*tls.Certificate
@@ -232,8 +246,29 @@ func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) {
			return
		}

		upstreamResp.Write(tlsConn)
		upstreamResp.Body.Close()
		if p.middleware != nil {
			if rule := p.middleware.Match(host, req.URL.Path); rule != nil {
				body, err := io.ReadAll(upstreamResp.Body)
				upstreamResp.Body.Close()
				if err != nil {
					log.Printf("ERROR: middleware failed to read response body from %s: %v", host, err)
				} else {
					if err := rule.SaveResponse(body); err != nil {
						log.Printf("ERROR: middleware failed to save response to %s: %v", rule.Dest, err)
					} else {
						log.Printf("MIDDLEWARE saved %s%s -> %s", host, req.URL.Path, rule.Dest)
					}
					upstreamResp.Body = io.NopCloser(bytes.NewReader(body))
				}
				upstreamResp.Write(tlsConn)
			} else {
				upstreamResp.Write(tlsConn)
				upstreamResp.Body.Close()
			}
		} else {
			upstreamResp.Write(tlsConn)
			upstreamResp.Body.Close()
		}

		if req.Close || upstreamResp.Close {
			return
diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go
index 2746587..673d686 100644
--- a/proxy/proxy_test.go
+++ b/proxy/proxy_test.go
@@ -3,6 +3,7 @@ package proxy_test
import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
@@ -288,3 +289,72 @@ func TestShouldAllowCleanHTTPSRequest(t *testing.T) {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}

func TestMiddlewareSavesResponseBody(t *testing.T) {
	responseBody := `{"usage": 100}`
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(responseBody))
	}))
	defer backend.Close()

	backendURL, _ := url.Parse(backend.URL)
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	os.WriteFile(hostsFile, []byte(host+"\n"), 0644)

	destFile := filepath.Join(t.TempDir(), "usage.json")
	mwPath := filepath.Join(t.TempDir(), "middleware.yaml")
	os.WriteFile(mwPath, []byte(fmt.Sprintf(`middleware:
  - match: "%s/data"
    action: save_response
    dest: "%s"
`, host, destFile)), 0644)

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	p := proxy.New(hostsFile,
		proxy.WithCA(caCert, caKey),
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
		proxy.WithMiddleware(mwPath),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, _ := url.Parse(srv.URL)
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	resp, err := client.Get(backend.URL + "/data")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}

	got, err := os.ReadFile(destFile)
	if err != nil {
		t.Fatalf("dest file not written: %v", err)
	}
	if string(got) != responseBody {
		t.Errorf("expected %q, got %q", responseBody, got)
	}
}