a73x

proxy/proxy.go

Ref:   Size: 7.9 KiB

package proxy

import (
	"bufio"
	"bytes"
	"crypto/ecdsa"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"

	"github.com/xanderle/nono/ca"
	"github.com/xanderle/nono/scanner"
)

// Option is a functional option for configuring a Proxy. New applies the
// options in the order they are passed, after the hosts file path has been
// set, so a later Option may overwrite fields set by an earlier one.
type Option func(*Proxy)

// WithRules returns an Option that installs a request scanner built from the
// rules file at rulesPath. Loading is best-effort: if scanner.New fails, a
// warning is logged and the Proxy's scanner is left untouched (nil unless an
// earlier option set it), so request scanning stays disabled instead of the
// proxy failing to start.
func WithRules(rulesPath string) Option {
	return func(px *Proxy) {
		loaded, loadErr := scanner.New(rulesPath)
		if loadErr == nil {
			px.scanner = loaded
			return
		}
		log.Printf("WARNING: failed to load rules from %q: %v", rulesPath, loadErr)
	}
}

// WithCA returns an Option that enables TLS interception of CONNECT requests.
// cert and key are the CA used to produce per-host leaf certificates, and the
// option starts a fresh, empty leaf-certificate cache. Passing a nil cert
// leaves interception disabled, because CONNECT handling only intercepts when
// the stored CA certificate is non-nil.
func WithCA(cert *x509.Certificate, key *ecdsa.PrivateKey) Option {
	return func(px *Proxy) {
		px.caCert, px.caKey = cert, key
		px.certCache = map[string]*tls.Certificate{}
	}
}

// WithUpstreamTLS returns an Option that sets the base TLS configuration used
// when the proxy dials the real server while intercepting a CONNECT. The
// config is cloned for each connection and its ServerName is replaced with the
// target host, so cfg itself is never modified. A nil cfg means defaults.
func WithUpstreamTLS(cfg *tls.Config) Option {
	setUpstream := func(px *Proxy) { px.upstreamTLS = cfg }
	return setUpstream
}

// Proxy is an HTTP forward proxy that only allows connections to hosts listed
// in an allowlist file. When a scanner is configured, request contents are
// checked before being forwarded; when a CA is configured, CONNECT requests
// are TLS-intercepted so the HTTPS requests inside can be checked as well.
// Create one with New.
type Proxy struct {
	// Path of the allowlist (one host per line); it is re-read on every request.
	hostsFile   string
	// Request scanner installed by WithRules; nil disables request scanning.
	scanner     *scanner.Scanner
	// CA certificate and key installed by WithCA. A non-nil caCert switches
	// CONNECT handling from a blind tunnel to TLS interception.
	caCert      *x509.Certificate
	caKey       *ecdsa.PrivateKey
	// Leaf certificates generated so far, keyed by host; guarded by certMu.
	certCache   map[string]*tls.Certificate
	certMu      sync.Mutex
	// Base TLS config for upstream dials during interception; nil means a
	// default tls.Config is used.
	upstreamTLS *tls.Config
}

// New returns a Proxy that consults the allowlist file at hostsFile, after
// applying each Option in the order given.
func New(hostsFile string, opts ...Option) *Proxy {
	px := new(Proxy)
	px.hostsFile = hostsFile
	for i := range opts {
		opts[i](px)
	}
	return px
}

// ServeHTTP implements http.Handler. A request whose target host is not in
// the allowlist gets a 403 telling the user how to approve the host. Approved
// CONNECT requests are logged and handed to handleConnect (tunnel or TLS
// interception); every other approved request is forwarded by handleHTTP.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	host := extractHost(r.Host)

	switch {
	case !p.isApproved(host):
		log.Printf("DENIED  %s %s", r.Method, r.Host)
		msg := fmt.Sprintf("host %q not approved. Run: nono allow %s", host, host)
		http.Error(w, msg, http.StatusForbidden)
	case r.Method == http.MethodConnect:
		log.Printf("ALLOWED %s %s", r.Method, r.Host)
		p.handleConnect(w, r)
	default:
		p.handleHTTP(w, r)
	}
}

// isApproved reports whether host appears as a line in the allowlist file.
// The file is re-opened on every call, so hosts added with Allow take effect
// without restarting the proxy. Each line is compared after trimming
// surrounding whitespace, case-insensitively because DNS names are
// case-insensitive.
//
// An empty host is never approved. Because of that, blank lines in the file
// can never match. Without this check, a stray empty line would approve a
// request such as "CONNECT :443", whose host part is empty and which would
// dial the local machine.
//
// If the file cannot be opened the host is denied; a read error stops the scan
// and denies unless a match was already found.
func (p *Proxy) isApproved(host string) bool {
	if host == "" {
		return false
	}

	f, err := os.Open(p.hostsFile)
	if err != nil {
		return false
	}
	defer f.Close()

	// Named "lines" rather than "scanner" so the imported scanner package is
	// not shadowed.
	lines := bufio.NewScanner(f)
	for lines.Scan() {
		if strings.EqualFold(strings.TrimSpace(lines.Text()), host) {
			return true
		}
	}
	return false
}

// handleConnect dispatches an approved CONNECT request. With a CA configured,
// the TLS session is intercepted so the requests inside can be scanned.
// Otherwise the bytes are relayed through an opaque tunnel.
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	if p.caCert != nil {
		p.handleConnectMITM(w, r)
	} else {
		p.handleConnectTunnel(w, r)
	}
}

// handleConnectTunnel answers an approved CONNECT request without inspecting
// the traffic. It dials r.Host over TCP, hijacks the client connection,
// replies "200 Connection Established", and then relays raw bytes in both
// directions. A failed dial is reported to the client as 502 Bad Gateway.
//
// Client bytes are read through the bufio.Reader returned by Hijack. The HTTP
// server may already have buffered data the client sent right after the
// CONNECT headers, and reading the bare conn would silently drop it. When the
// client finishes sending, the target connection's write side is half-closed
// so the target sees EOF. The tunnel is torn down once the target finishes
// sending to the client.
func (p *Proxy) handleConnectTunnel(w http.ResponseWriter, r *http.Request) {
	// Check for hijack support before dialing, so an unsupported writer does
	// not leave an open upstream connection behind.
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}

	targetConn, err := net.Dial("tcp", r.Host)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer targetConn.Close()

	clientConn, clientBuf, err := hj.Hijack()
	if err != nil {
		log.Printf("ERROR: failed to hijack connection for %s: %v", r.Host, err)
		return
	}
	defer clientConn.Close()

	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		return
	}

	upstreamDone := make(chan struct{})
	go func() {
		defer close(upstreamDone)
		io.Copy(targetConn, clientBuf.Reader)
		// Propagate the client's EOF so the target can finish its reply.
		if hc, ok := targetConn.(interface{ CloseWrite() error }); ok {
			hc.CloseWrite()
		}
	}()

	io.Copy(clientConn, targetConn)

	// Closing both ends unblocks the client->target copier if it is still
	// reading or writing. Wait for it so no goroutine outlives the tunnel. The
	// deferred Close calls then only return an already-closed error, which is
	// ignored.
	clientConn.Close()
	targetConn.Close()
	<-upstreamDone
}

// mitmConn wraps a hijacked client connection so that reads go through the
// bufio.Reader returned by Hijack. The HTTP server may already have buffered
// bytes the client sent right after the CONNECT headers (for example, the
// start of its TLS ClientHello). Reading the bare net.Conn would lose them.
type mitmConn struct {
	net.Conn
	r *bufio.Reader
}

// Read drains the hijack buffer first and then reads from the connection.
func (c *mitmConn) Read(b []byte) (int, error) {
	return c.r.Read(b)
}

// readFinalResponse reads the response to req from br, discarding interim 1xx
// responses other than 101 Switching Protocols. An example is "100 Continue",
// sent when the client used "Expect: 100-continue". The request body has
// already been written in full by the time this is called. Relaying the
// interim response and then waiting for the next client request would leave
// the real response unread on the upstream connection.
func readFinalResponse(br *bufio.Reader, req *http.Request) (*http.Response, error) {
	for {
		resp, err := http.ReadResponse(br, req)
		if err != nil {
			return nil, err
		}
		code := resp.StatusCode
		if code < 100 || code >= 200 || code == http.StatusSwitchingProtocols {
			return resp, nil
		}
		resp.Body.Close()
	}
}

// handleConnectMITM answers an approved CONNECT request by intercepting its
// TLS session so the HTTP requests inside can be scanned. It works as follows:
//
//   - It presents the client a leaf certificate for the target host, obtained
//     from getOrCreateLeaf.
//   - It opens its own TLS connection to r.Host, using a clone of upstreamTLS
//     (or a default config) with ServerName set to the host.
//   - It relays HTTP/1.x request/response pairs one at a time.
//
// A request for which scanRequest reports findings is answered with 403, and
// the connection is closed without forwarding it. The relay loop ends in any
// of these cases: either side closes, a read or write fails, or either
// message asks for the connection to be closed.
//
// The leaf certificate and the upstream connection are prepared before the
// client is told "200 Connection Established". Failures there are therefore
// reported as ordinary HTTP errors on the CONNECT request (500 and 502, like
// the plain tunnel's 502 on dial failure), rather than as a connection that
// dies right after the TLS handshake.
func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) {
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}

	host := extractHost(r.Host)

	leafCert, err := p.getOrCreateLeaf(host)
	if err != nil {
		log.Printf("ERROR: failed to get leaf cert for %s: %v", host, err)
		http.Error(w, "failed to create certificate for "+host, http.StatusInternalServerError)
		return
	}

	upstreamCfg := p.upstreamTLS
	if upstreamCfg == nil {
		upstreamCfg = &tls.Config{}
	}
	cfg := upstreamCfg.Clone()
	cfg.ServerName = host

	upstreamConn, err := tls.Dial("tcp", r.Host, cfg)
	if err != nil {
		log.Printf("ERROR: failed to dial upstream %s: %v", r.Host, err)
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer upstreamConn.Close()

	clientConn, clientBuf, err := hj.Hijack()
	if err != nil {
		log.Printf("ERROR: failed to hijack connection for %s: %v", r.Host, err)
		return
	}
	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		clientConn.Close()
		return
	}

	tlsConn := tls.Server(&mitmConn{Conn: clientConn, r: clientBuf.Reader}, &tls.Config{
		Certificates: []tls.Certificate{*leafCert},
	})
	// Closing tlsConn also closes the underlying client connection.
	defer tlsConn.Close()
	if err := tlsConn.Handshake(); err != nil {
		log.Printf("ERROR: TLS handshake with client failed for %s: %v", host, err)
		return
	}

	clientReader := bufio.NewReader(tlsConn)
	upstreamReader := bufio.NewReader(upstreamConn)

	for {
		req, err := http.ReadRequest(clientReader)
		if err != nil {
			if err != io.EOF {
				log.Printf("ERROR: failed to read request from TLS conn for %s: %v", host, err)
			}
			return
		}

		if findings := p.scanRequest(req, host); len(findings) > 0 {
			rules := make([]string, 0, len(findings))
			for _, f := range findings {
				rules = append(rules, f.Rule)
			}
			joined := strings.Join(rules, ", ")
			log.Printf("BLOCKED %s %s [%s]", req.Method, r.Host, joined)
			msg := fmt.Sprintf("request blocked: contains sensitive data (%s)\n", joined)
			resp := &http.Response{
				StatusCode:    http.StatusForbidden,
				ProtoMajor:    1,
				ProtoMinor:    1,
				Request:       req, // lets Response.Write omit the body for HEAD
				Header:        make(http.Header),
				ContentLength: int64(len(msg)),
				Body:          io.NopCloser(strings.NewReader(msg)),
			}
			resp.Header.Set("Content-Type", "text/plain")
			resp.Header.Set("Connection", "close")
			if err := resp.Write(tlsConn); err != nil {
				log.Printf("ERROR: failed to write blocked response for %s: %v", host, err)
			}
			return
		}

		req.RequestURI = ""
		if err := req.Write(upstreamConn); err != nil {
			log.Printf("ERROR: failed to forward request to %s: %v", r.Host, err)
			return
		}

		upstreamResp, err := readFinalResponse(upstreamReader, req)
		if err != nil {
			log.Printf("ERROR: failed to read response from %s: %v", r.Host, err)
			return
		}

		writeErr := upstreamResp.Write(tlsConn)
		upstreamResp.Body.Close()
		if writeErr != nil {
			log.Printf("ERROR: failed to write response for %s: %v", r.Host, writeErr)
			return
		}

		if req.Close || upstreamResp.Close {
			return
		}
	}
}

// getOrCreateLeaf returns the cached leaf certificate for host. On the first
// request for a host, it creates one with ca.GenerateLeaf from the proxy's CA
// certificate and key, and caches it. certMu is held for the whole call, so
// concurrent first requests for the same host generate only one certificate.
// Entries are never evicted.
// NOTE(review): if ca.GenerateLeaf issues short-lived certificates, confirm
// that cached entries cannot outlive their validity period.
func (p *Proxy) getOrCreateLeaf(host string) (*tls.Certificate, error) {
	p.certMu.Lock()
	defer p.certMu.Unlock()

	cached, found := p.certCache[host]
	if !found {
		leaf, err := ca.GenerateLeaf(host, p.caCert, p.caKey)
		if err != nil {
			return nil, err
		}
		cached = &leaf
		p.certCache[host] = cached
	}
	return cached, nil
}

// scanRequest runs the configured scanner over the parts of r that can carry
// sensitive data. It returns the findings reported for host, or nil when no
// scanner is configured.
//
// The scanned text is built in three parts:
//
//   - The request target (path and query string, where secrets such as API
//     keys are often passed).
//   - One "Key: Value" line per header value.
//   - The body.
//
// Reading the body consumes it, so r.Body is replaced with a reader that
// replays the bytes, and the request can still be forwarded unchanged. If
// reading the body fails part-way, the bytes read so far are still scanned
// rather than skipping the scan entirely. The replacement body replays those
// bytes and then continues from the original body, so the forwarded body is
// never silently truncated to the bytes already read.
func (p *Proxy) scanRequest(r *http.Request, host string) []scanner.Finding {
	if p.scanner == nil {
		return nil
	}

	var buf bytes.Buffer

	if r.URL != nil {
		buf.WriteString(r.URL.RequestURI())
		buf.WriteByte('\n')
	}

	// Map iteration order is randomized, so header line order varies between
	// calls.
	for k, vv := range r.Header {
		for _, v := range vv {
			fmt.Fprintf(&buf, "%s: %s\n", k, v)
		}
	}

	if r.Body != nil {
		body, err := io.ReadAll(r.Body)
		buf.Write(body)
		if err != nil {
			log.Printf("WARNING: failed to read request body: %v", err)
			orig := r.Body
			r.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(bytes.NewReader(body), orig), orig}
		} else {
			r.Body = io.NopCloser(bytes.NewReader(body))
		}
	}

	return p.scanner.Scan(buf.Bytes(), host)
}

// hopByHopHeaders lists the connection-scoped header fields that a proxy must
// not forward to the next hop (see RFC 9110 §7.6.1). It mirrors the list used
// by net/http/httputil.ReverseProxy, including the legacy Proxy-Connection and
// Keep-Alive fields.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// removeHopByHopHeaders deletes every hop-by-hop header from h, including any
// additional field names the sender listed in its Connection header.
func removeHopByHopHeaders(h http.Header) {
	for _, v := range h.Values("Connection") {
		for _, name := range strings.Split(v, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range hopByHopHeaders {
		h.Del(name)
	}
}

// handleHTTP forwards an approved plain-HTTP (non-CONNECT) proxy request.
// The request is first passed through scanRequest. If the scanner reports any
// findings, the request is rejected with 403 and the matching rule names.
//
// Otherwise a copy of the request is sent with http.DefaultTransport. The copy
// has its hop-by-hop headers stripped, in particular Proxy-Authorization, so
// credentials meant for this proxy are not leaked upstream. The upstream
// status, end-to-end headers, and body are then copied back to the client.
// Transport errors are reported as 502 Bad Gateway.
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	host := extractHost(r.Host)
	if findings := p.scanRequest(r, host); len(findings) > 0 {
		rules := make([]string, 0, len(findings))
		for _, f := range findings {
			rules = append(rules, f.Rule)
		}
		joined := strings.Join(rules, ", ")
		log.Printf("BLOCKED %s %s [%s]", r.Method, r.Host, joined)
		http.Error(w, fmt.Sprintf("request blocked: contains sensitive data (%s)", joined), http.StatusForbidden)
		return
	}

	log.Printf("ALLOWED %s %s", r.Method, r.Host)

	// Clone rather than mutate r. RoundTrip requires an empty RequestURI, and
	// the header edits must not change the server's own view of the request.
	outReq := r.Clone(r.Context())
	outReq.RequestURI = ""
	removeHopByHopHeaders(outReq.Header)

	resp, err := http.DefaultTransport.RoundTrip(outReq)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	removeHopByHopHeaders(resp.Header)
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	io.Copy(w, resp.Body)
}

// Allow adds host to the approved-hosts file at hostsFile, creating the file
// with mode 0644 if it does not exist. Surrounding whitespace is trimmed from
// host, and the call is a no-op when the host is already listed on its own
// line.
//
// It returns an error if host is empty or contains interior whitespace. Such a
// value could never match a single trimmed line of the file, and an empty
// entry would only add a blank line. If the existing file does not end with a
// newline, one is written first so the new host is not glued onto the last
// entry.
func Allow(hostsFile, host string) error {
	fields := strings.Fields(host)
	if len(fields) != 1 {
		return fmt.Errorf("invalid host %q: must be a single non-empty name", host)
	}
	host = fields[0]

	// A missing file simply means no hosts are approved yet; any other read
	// error is reported rather than risking a duplicate or corrupt append.
	data, err := os.ReadFile(hostsFile)
	if err != nil && !os.IsNotExist(err) {
		return err
	}

	for _, line := range strings.Split(string(data), "\n") {
		if strings.TrimSpace(line) == host {
			return nil
		}
	}

	entry := host + "\n"
	if len(data) > 0 && data[len(data)-1] != '\n' {
		entry = "\n" + entry
	}

	f, err := os.OpenFile(hostsFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return err
	}
	if _, err := f.WriteString(entry); err != nil {
		f.Close()
		return err
	}
	// Close can report a deferred write error, so it must not be ignored.
	return f.Close()
}

// extractHost returns the host part of a "host:port" authority such as
// r.Host, with the port removed. IPv6 literals are returned without their
// square brackets whether or not a port is present: "[::1]:443" and "[::1]"
// both yield "::1". The result can therefore be used directly as a TLS
// ServerName or certificate name. Any other input without a port is returned
// unchanged.
func extractHost(hostport string) string {
	if host, _, err := net.SplitHostPort(hostport); err == nil {
		return host
	}
	// No port (or not host:port at all): strip IPv6 brackets if present.
	if len(hostport) >= 2 && hostport[0] == '[' && hostport[len(hostport)-1] == ']' {
		return hostport[1 : len(hostport)-1]
	}
	return hostport
}