proxy/proxy.go
Ref: Size: 7.9 KiB
package proxy
import (
"bufio"
"bytes"
"crypto/ecdsa"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
"github.com/xanderle/nono/ca"
"github.com/xanderle/nono/scanner"
)
// Option is a functional option for configuring a Proxy. Options are passed to
// New, which applies them to the newly constructed Proxy in the order given.
type Option func(*Proxy)
// WithRules builds an Option that tries to construct a scanner from the rules
// file at rulesPath and installs it on the Proxy. When the rules cannot be
// loaded, a warning is logged and the Proxy's scanner field is left unchanged.
func WithRules(rulesPath string) Option {
	return func(p *Proxy) {
		loaded, err := scanner.New(rulesPath)
		if err == nil {
			p.scanner = loaded
			return
		}
		log.Printf("WARNING: failed to load rules from %q: %v", rulesPath, err)
	}
}
// WithCA builds an Option that stores the CA certificate and private key on
// the Proxy and resets its leaf-certificate cache to an empty map. A non-nil
// CA certificate is what makes handleConnect choose MITM interception.
func WithCA(cert *x509.Certificate, key *ecdsa.PrivateKey) Option {
	return func(p *Proxy) {
		p.certCache = map[string]*tls.Certificate{}
		p.caCert, p.caKey = cert, key
	}
}
// WithUpstreamTLS builds an Option that records cfg as the TLS configuration
// the Proxy uses as its base when dialing upstream servers.
func WithUpstreamTLS(cfg *tls.Config) Option {
	apply := func(p *Proxy) { p.upstreamTLS = cfg }
	return apply
}
// Proxy is an HTTP proxy that only allows connections to approved hosts.
// It is configured through New and the With* options; the zero value has no
// hosts file, so isApproved rejects every host.
type Proxy struct {
// hostsFile is the path of the allowlist file; isApproved re-reads it on
// every request.
hostsFile string
// scanner inspects outgoing requests; it is set by WithRules, and when nil
// scanRequest reports no findings.
scanner *scanner.Scanner
// caCert and caKey are set by WithCA. When caCert is nil, CONNECT requests
// are tunneled without interception instead of being MITM'd.
caCert *x509.Certificate
caKey *ecdsa.PrivateKey
// certCache holds generated leaf certificates keyed by host; it is
// initialized by WithCA and accessed under certMu in getOrCreateLeaf.
certCache map[string]*tls.Certificate
// certMu guards certCache.
certMu sync.Mutex
// upstreamTLS is the base TLS config cloned for each upstream dial in MITM
// mode; when nil an empty tls.Config is used.
upstreamTLS *tls.Config
}
// New constructs a Proxy that checks hosts against the allowlist file at
// hostsFile, then applies each option in opts to it in order.
func New(hostsFile string, opts ...Option) *Proxy {
	p := new(Proxy)
	p.hostsFile = hostsFile
	for i := range opts {
		opts[i](p)
	}
	return p
}
// ServeHTTP implements http.Handler. The request's host (port stripped) is
// checked against the allowlist first; unapproved hosts get a 403. Approved
// CONNECT requests go to handleConnect, everything else to handleHTTP.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	target := extractHost(r.Host)
	switch {
	case !p.isApproved(target):
		log.Printf("DENIED %s %s", r.Method, r.Host)
		msg := fmt.Sprintf("host %q not approved. Run: nono allow %s", target, target)
		http.Error(w, msg, http.StatusForbidden)
	case r.Method == http.MethodConnect:
		log.Printf("ALLOWED %s %s", r.Method, r.Host)
		p.handleConnect(w, r)
	default:
		p.handleHTTP(w, r)
	}
}
// isApproved reports whether host appears as a line in the allowlist file.
// Lines are compared after trimming surrounding whitespace. The file is
// re-read on every call, so edits take effect immediately. Any failure to open
// or read the file results in the host being treated as not approved.
func (p *Proxy) isApproved(host string) bool {
	// A blank line in the file trims to "", so an empty host must be rejected
	// up front or it would be approved by any blank line.
	if host == "" {
		return false
	}

	f, err := os.Open(p.hostsFile)
	if err != nil {
		return false
	}
	defer f.Close()

	lines := bufio.NewScanner(f)
	for lines.Scan() {
		if strings.TrimSpace(lines.Text()) == host {
			return true
		}
	}
	// Scan stops silently on I/O errors and over-long lines; surface that so a
	// truncated read is not mistaken for "host absent".
	if err := lines.Err(); err != nil {
		log.Printf("WARNING: failed to read hosts file %q: %v", p.hostsFile, err)
	}
	return false
}
// handleConnect dispatches a CONNECT request: with a CA certificate configured
// the connection is intercepted (MITM); without one it is tunneled as-is.
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	serve := p.handleConnectMITM
	if p.caCert == nil {
		serve = p.handleConnectTunnel
	}
	serve(w, r)
}
// handleConnectTunnel serves a CONNECT request without interception: it dials
// r.Host over TCP, hijacks the client connection, acknowledges the CONNECT
// with a 200, and then copies bytes in both directions until the
// upstream-to-client direction ends. Both connections are closed on return.
//
// A dial failure is reported to the client as 502; a ResponseWriter that
// cannot be hijacked is reported as 500.
func (p *Proxy) handleConnectTunnel(w http.ResponseWriter, r *http.Request) {
	targetConn, err := net.Dial("tcp", r.Host)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	// Deferred so every exit path below releases the upstream connection,
	// including the "hijacking not supported" one.
	defer targetConn.Close()

	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}
	clientConn, _, err := hj.Hijack()
	if err != nil {
		log.Printf("ERROR: failed to hijack connection for %s: %v", r.Host, err)
		return
	}
	defer clientConn.Close()

	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		log.Printf("ERROR: failed to acknowledge CONNECT for %s: %v", r.Host, err)
		return
	}

	// Client -> upstream runs in the background; it ends when either
	// connection is closed by the deferred calls above. Copy errors are
	// expected when a peer hangs up and are deliberately ignored.
	go io.Copy(targetConn, clientConn)
	io.Copy(clientConn, targetConn)
}
// handleConnectMITM serves a CONNECT request by intercepting the TLS session.
//
// It hijacks the client connection, acknowledges the CONNECT, and completes a
// TLS handshake with the client using a leaf certificate for the requested
// host obtained from getOrCreateLeaf. It then opens its own TLS connection to
// r.Host (using a clone of p.upstreamTLS, or an empty config, with ServerName
// set to the host) and relays HTTP/1.x requests one at a time:
//
//   - each request is passed to scanRequest; if there are findings, a 403 is
//     written to the client and the connection is closed;
//   - otherwise the request is written upstream and the response is written
//     back to the client.
//
// The loop ends when the client stops sending requests, on any I/O error, or
// when either side asks for the connection to be closed.
func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) {
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}
	clientConn, _, err := hj.Hijack()
	if err != nil {
		log.Printf("ERROR: failed to hijack connection for %s: %v", r.Host, err)
		return
	}
	// The 200 is sent before the upstream dial because the client will not
	// start its TLS handshake until the CONNECT has been acknowledged.
	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		log.Printf("ERROR: failed to acknowledge CONNECT for %s: %v", r.Host, err)
		clientConn.Close()
		return
	}

	host := extractHost(r.Host)
	leafCert, err := p.getOrCreateLeaf(host)
	if err != nil {
		log.Printf("ERROR: failed to get leaf cert for %s: %v", host, err)
		clientConn.Close()
		return
	}

	tlsConn := tls.Server(clientConn, &tls.Config{
		Certificates: []tls.Certificate{*leafCert},
	})
	// Closing the TLS connection also closes the underlying clientConn.
	defer tlsConn.Close()
	if err := tlsConn.Handshake(); err != nil {
		log.Printf("ERROR: TLS handshake with client failed for %s: %v", host, err)
		return
	}

	upstreamCfg := p.upstreamTLS
	if upstreamCfg == nil {
		upstreamCfg = &tls.Config{}
	}
	// Clone so the shared config is never mutated by the per-host ServerName.
	cfg := upstreamCfg.Clone()
	cfg.ServerName = host
	upstreamConn, err := tls.Dial("tcp", r.Host, cfg)
	if err != nil {
		log.Printf("ERROR: failed to dial upstream %s: %v", r.Host, err)
		return
	}
	defer upstreamConn.Close()

	clientReader := bufio.NewReader(tlsConn)
	upstreamReader := bufio.NewReader(upstreamConn)
	for {
		req, err := http.ReadRequest(clientReader)
		if err != nil {
			// A clean EOF just means the client is done with the connection.
			if err != io.EOF {
				log.Printf("ERROR: failed to read request from TLS conn for %s: %v", host, err)
			}
			return
		}

		if findings := p.scanRequest(req, host); len(findings) > 0 {
			rules := make([]string, 0, len(findings))
			for _, f := range findings {
				rules = append(rules, f.Rule)
			}
			joined := strings.Join(rules, ", ")
			log.Printf("BLOCKED %s %s [%s]", req.Method, r.Host, joined)

			body := fmt.Sprintf("request blocked: contains sensitive data (%s)\n", joined)
			resp := &http.Response{
				StatusCode:    http.StatusForbidden,
				ProtoMajor:    1,
				ProtoMinor:    1,
				Header:        make(http.Header),
				Body:          io.NopCloser(strings.NewReader(body)),
				ContentLength: int64(len(body)),
				Close:         true,
			}
			resp.Header.Set("Content-Type", "text/plain")
			resp.Header.Set("Connection", "close")
			if err := resp.Write(tlsConn); err != nil {
				log.Printf("ERROR: failed to write blocked response for %s: %v", r.Host, err)
			}
			return
		}

		req.RequestURI = ""
		if err := req.Write(upstreamConn); err != nil {
			log.Printf("ERROR: failed to forward request to %s: %v", r.Host, err)
			return
		}
		upstreamResp, err := http.ReadResponse(upstreamReader, req)
		if err != nil {
			log.Printf("ERROR: failed to read response from %s: %v", r.Host, err)
			return
		}
		writeErr := upstreamResp.Write(tlsConn)
		upstreamResp.Body.Close()
		if writeErr != nil {
			// The client side is broken; relaying further requests on this
			// connection cannot succeed.
			log.Printf("ERROR: failed to relay response from %s to client: %v", r.Host, writeErr)
			return
		}
		if req.Close || upstreamResp.Close {
			return
		}
	}
}
// getOrCreateLeaf returns the cached leaf certificate for host, or, on a cache
// miss, obtains one from ca.GenerateLeaf using the configured CA certificate
// and key, caches it, and returns it. The whole operation runs under certMu,
// so concurrent callers are serialized.
func (p *Proxy) getOrCreateLeaf(host string) (*tls.Certificate, error) {
	p.certMu.Lock()
	defer p.certMu.Unlock()

	cached, found := p.certCache[host]
	if found {
		return cached, nil
	}

	generated, err := ca.GenerateLeaf(host, p.caCert, p.caKey)
	if err != nil {
		return nil, err
	}
	cached = &generated
	p.certCache[host] = cached
	return cached, nil
}
// scanRequest runs the configured scanner over the request's headers and body
// and returns its findings. It returns nil when no scanner is configured.
//
// The scanned text is every header serialized as a "Key: Value\n" line
// (in map iteration order) followed by the raw body bytes. The body is read
// fully into memory and r.Body is replaced with a reader over the same bytes
// so the request can still be forwarded afterwards.
//
// If reading the body fails, the headers and whatever part of the body was
// read are still scanned, so a broken body cannot be used to skip scanning.
func (p *Proxy) scanRequest(r *http.Request, host string) []scanner.Finding {
	if p.scanner == nil {
		return nil
	}

	var buf bytes.Buffer
	// Serialize headers as "Key: Value\n" lines.
	for k, vv := range r.Header {
		for _, v := range vv {
			fmt.Fprintf(&buf, "%s: %s\n", k, v)
		}
	}

	if r.Body != nil {
		orig := r.Body
		body, err := io.ReadAll(orig)
		buf.Write(body)
		if err != nil {
			log.Printf("WARNING: failed to read request body: %v", err)
			// Replay the bytes already consumed and then fall through to the
			// original body, so a forwarder reads the same stream it would
			// have seen and does not mistake the partial body for a complete
			// one. NOTE(review): assumes orig keeps failing on further reads.
			r.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(bytes.NewReader(body), orig), orig}
		} else {
			r.Body = io.NopCloser(bytes.NewReader(body))
		}
	}

	return p.scanner.Scan(buf.Bytes(), host)
}
// handleHTTP serves a plain (non-CONNECT) proxy request. The request is first
// scanned; any findings produce a 403 naming the matched rules. Otherwise the
// request is sent through http.DefaultTransport and the upstream status,
// headers and body are copied back to the client. A transport error is
// reported as 502.
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	if findings := p.scanRequest(r, extractHost(r.Host)); len(findings) > 0 {
		names := make([]string, len(findings))
		for i, f := range findings {
			names[i] = f.Rule
		}
		joined := strings.Join(names, ", ")
		log.Printf("BLOCKED %s %s [%s]", r.Method, r.Host, joined)
		http.Error(w, fmt.Sprintf("request blocked: contains sensitive data (%s)", joined), http.StatusForbidden)
		return
	}

	log.Printf("ALLOWED %s %s", r.Method, r.Host)

	// RoundTrip rejects requests that still carry the server-side RequestURI.
	r.RequestURI = ""
	resp, err := http.DefaultTransport.RoundTrip(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	dst := w.Header()
	for name, values := range resp.Header {
		for _, v := range values {
			dst.Add(name, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	io.Copy(w, resp.Body)
}
// Allow adds host to the approved hosts file at hostsFile, creating the file
// if it does not exist. The host is trimmed of surrounding whitespace and is
// not written again if an identical (trimmed) line is already present.
//
// It returns an error if host is empty or contains a line break, if the
// existing file cannot be read for a reason other than not existing, or if
// the file cannot be opened or written.
func Allow(hostsFile, host string) error {
	host = strings.TrimSpace(host)
	if host == "" {
		// A blank line would be written, which is never a meaningful entry.
		return fmt.Errorf("allow: empty host")
	}
	if strings.ContainsAny(host, "\r\n") {
		// One call must add exactly one allowlist entry.
		return fmt.Errorf("allow: host %q contains a line break", host)
	}

	data, err := os.ReadFile(hostsFile)
	if err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("reading hosts file %q: %w", hostsFile, err)
	}
	for _, line := range strings.Split(string(data), "\n") {
		if strings.TrimSpace(line) == host {
			return nil
		}
	}

	f, err := os.OpenFile(hostsFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return err
	}

	entry := host + "\n"
	// If the file does not end with a newline, start a fresh line so the new
	// host is not glued onto the last existing entry.
	if len(data) > 0 && data[len(data)-1] != '\n' {
		entry = "\n" + entry
	}
	if _, err := f.WriteString(entry); err != nil {
		f.Close()
		return err
	}
	// Close can report a deferred write failure, so its error is returned.
	return f.Close()
}
// extractHost returns the host part of hostport with any ":port" suffix
// removed. Input without a port is returned unchanged.
//
// Bracketed IPv6 literals keep their brackets ("[::1]:443" and "[::1]" both
// yield "[::1]"). An unbracketed value containing more than one colon is
// taken to be a bare IPv6 address and is returned as-is, since it cannot
// carry a port unambiguously.
func extractHost(hostport string) string {
	if strings.HasPrefix(hostport, "[") {
		if end := strings.IndexByte(hostport, ']'); end != -1 {
			return hostport[:end+1]
		}
		return hostport
	}
	// Exactly one colon means "host:port"; zero or several means no port.
	first := strings.IndexByte(hostport, ':')
	if first != -1 && first == strings.LastIndexByte(hostport, ':') {
		return hostport[:first]
	}
	return hostport
}