a73x

e6db3821

feat: integrate scanner into HTTP request handling

a73x   2026-03-29 16:26

Add functional options pattern to Proxy, WithRules option to load a scanner,
scanRequest method that checks headers and body, and 403 blocking in handleHTTP.


diff --git a/proxy/proxy.go b/proxy/proxy.go
index a8e630f..f3fc126 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -2,6 +2,7 @@ package proxy

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"log"
@@ -9,16 +10,38 @@ import (
	"net/http"
	"os"
	"strings"

	"github.com/xanderle/nono/scanner"
)

// Option is a functional option for configuring a Proxy.
type Option func(*Proxy)

// WithRules returns an Option that loads a scanner from the given rules file path.
func WithRules(rulesPath string) Option {
	return func(p *Proxy) {
		s, err := scanner.New(rulesPath)
		if err != nil {
			log.Printf("WARNING: failed to load rules from %q: %v", rulesPath, err)
			return
		}
		p.scanner = s
	}
}

// Proxy is an HTTP proxy that only allows connections to approved hosts.
type Proxy struct {
	hostsFile string
	scanner   *scanner.Scanner
}

// New creates a new Proxy that checks hosts against the given allowlist file.
func New(hostsFile string) *Proxy {
	return &Proxy{hostsFile: hostsFile}
func New(hostsFile string, opts ...Option) *Proxy {
	p := &Proxy{hostsFile: hostsFile}
	for _, opt := range opts {
		opt(p)
	}
	return p
}

func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -85,7 +108,45 @@ func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	targetConn.Close()
}

func (p *Proxy) scanRequest(r *http.Request) []scanner.Finding {
	if p.scanner == nil {
		return nil
	}

	var buf bytes.Buffer

	// Serialize headers as "Key: Value\n" lines
	for k, vv := range r.Header {
		for _, v := range vv {
			fmt.Fprintf(&buf, "%s: %s\n", k, v)
		}
	}

	// Read body if present
	if r.Body != nil {
		body, err := io.ReadAll(r.Body)
		if err == nil {
			buf.Write(body)
			// Restore the body so it can be forwarded
			r.Body = io.NopCloser(bytes.NewReader(body))
		}
	}

	return p.scanner.Scan(buf.Bytes())
}

func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	findings := p.scanRequest(r)
	if len(findings) > 0 {
		rules := make([]string, 0, len(findings))
		for _, f := range findings {
			rules = append(rules, f.Rule)
		}
		log.Printf("BLOCKED %s %s [%s]", r.Method, r.Host, strings.Join(rules, ", "))
		http.Error(w, fmt.Sprintf("request blocked: contains sensitive data (%s)", strings.Join(rules, ", ")), http.StatusForbidden)
		return
	}

	r.RequestURI = ""
	resp, err := http.DefaultTransport.RoundTrip(r)
	if err != nil {
diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go
index 666ba3d..cff0dd7 100644
--- a/proxy/proxy_test.go
+++ b/proxy/proxy_test.go
@@ -115,3 +115,69 @@ func TestAllowShouldCreateFileIfMissing(t *testing.T) {
		t.Errorf("expected approved_hosts to contain example.com, got %q", string(data))
	}
}

func writeTestRules(t *testing.T) string {
	t.Helper()
	dir := t.TempDir()
	path := filepath.Join(dir, "rules.yaml")
	os.WriteFile(path, []byte("rules:\n  - name: ssh-private-key\n    pattern: \"-----BEGIN RSA PRIVATE KEY-----\"\n  - name: aws-access-key\n    pattern: \"AKIA[0-9A-Z]{16}\"\n"), 0644)
	return path
}

func TestShouldBlockRequestWithSSHKey(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, _ := url.Parse(backend.URL)
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644)

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile, proxy.WithRules(rulesPath))
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)

	body := strings.NewReader("data=-----BEGIN RSA PRIVATE KEY-----\nMIIE...")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403, got %d", resp.StatusCode)
	}
}

func TestShouldAllowCleanRequest(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, _ := url.Parse(backend.URL)
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644)

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile, proxy.WithRules(rulesPath))
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)

	body := strings.NewReader("just normal data")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}