70c1cc4e
feat: wire middleware into proxy MITM loop
a73x 2026-03-31 05:49
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
diff --git a/proxy/proxy.go b/proxy/proxy.go index 1a6c468..f336224 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -16,6 +16,7 @@ import ( "sync" "github.com/xanderle/nono/ca" "github.com/xanderle/nono/middleware" "github.com/xanderle/nono/scanner" ) @@ -34,6 +35,18 @@ func WithRules(rulesPath string) Option { } } // WithMiddleware returns an Option that loads middleware from the given config path. func WithMiddleware(path string) Option { return func(p *Proxy) { mw, err := middleware.New(path) if err != nil { log.Printf("WARNING: failed to load middleware from %q: %v", path, err) return } p.middleware = mw } } // WithCA returns an Option that enables MITM interception using the given CA cert and key. func WithCA(cert *x509.Certificate, key *ecdsa.PrivateKey) Option { return func(p *Proxy) { @@ -54,6 +67,7 @@ func WithUpstreamTLS(cfg *tls.Config) Option { type Proxy struct { hostsFile string scanner *scanner.Scanner middleware *middleware.Middleware caCert *x509.Certificate caKey *ecdsa.PrivateKey certCache map[string]*tls.Certificate @@ -232,8 +246,29 @@ func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) { return } upstreamResp.Write(tlsConn) upstreamResp.Body.Close() if p.middleware != nil { if rule := p.middleware.Match(host, req.URL.Path); rule != nil { body, err := io.ReadAll(upstreamResp.Body) upstreamResp.Body.Close() if err != nil { log.Printf("ERROR: middleware failed to read response body from %s: %v", host, err) } else { if err := rule.SaveResponse(body); err != nil { log.Printf("ERROR: middleware failed to save response to %s: %v", rule.Dest, err) } else { log.Printf("MIDDLEWARE saved %s%s -> %s", host, req.URL.Path, rule.Dest) } upstreamResp.Body = io.NopCloser(bytes.NewReader(body)) } upstreamResp.Write(tlsConn) } else { upstreamResp.Write(tlsConn) upstreamResp.Body.Close() } } else { upstreamResp.Write(tlsConn) upstreamResp.Body.Close() } if req.Close || upstreamResp.Close { return diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 2746587..673d686 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -3,6 +3,7 @@ package proxy_test import ( "crypto/tls" "crypto/x509" "fmt" "net/http" "net/http/httptest" "net/url" @@ -288,3 +289,72 @@ func TestShouldAllowCleanHTTPSRequest(t *testing.T) { t.Errorf("expected 200, got %d", resp.StatusCode) } } func TestMiddlewareSavesResponseBody(t *testing.T) { responseBody := `{"usage": 100}` backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write([]byte(responseBody)) })) defer backend.Close() backendURL, _ := url.Parse(backend.URL) host := backendURL.Hostname() hostsFile := filepath.Join(t.TempDir(), "approved_hosts") os.WriteFile(hostsFile, []byte(host+"\n"), 0644) destFile := filepath.Join(t.TempDir(), "usage.json") mwPath := filepath.Join(t.TempDir(), "middleware.yaml") os.WriteFile(mwPath, []byte(fmt.Sprintf(`middleware: - match: "%s/data" action: save_response dest: "%s" `, host, destFile)), 0644) caDir := t.TempDir() caCert, caKey, err := nca.LoadOrCreate(caDir) if err != nil { t.Fatalf("CA setup: %v", err) } p := proxy.New(hostsFile, proxy.WithCA(caCert, caKey), proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}), proxy.WithMiddleware(mwPath), ) srv := httptest.NewServer(p) defer srv.Close() caPool := x509.NewCertPool() caPool.AddCert(caCert) proxyURL, _ := url.Parse(srv.URL) client := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyURL(proxyURL), TLSClientConfig: &tls.Config{ RootCAs: caPool, }, }, } resp, err := client.Get(backend.URL + "/data") if err != nil { t.Fatalf("unexpected error: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("expected 200, got %d", resp.StatusCode) } got, err := os.ReadFile(destFile) if err != nil { t.Fatalf("dest file not written: %v", err) } if string(got) != responseBody { t.Errorf("expected %q, got %q", responseBody, got) } }