a73x

proxy/proxy_test.go

Ref:   Size: 7.6 KiB

package proxy_test

import (
	"crypto/tls"
	"crypto/x509"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"testing"

	nca "github.com/xanderle/nono/ca"
	"github.com/xanderle/nono/proxy"
)

// newProxyClient returns an HTTP client that sends every request through the
// proxy at proxyURL. The URL is parsed once up front; a malformed value fails
// the calling test immediately rather than surfacing later as a confusing
// per-request transport error.
func newProxyClient(t *testing.T, proxyURL string) *http.Client {
	t.Helper()
	u, err := url.Parse(proxyURL)
	if err != nil {
		t.Fatalf("parsing proxy URL %q: %v", proxyURL, err)
	}
	return &http.Client{
		Transport: &http.Transport{Proxy: http.ProxyURL(u)},
	}
}

// TestShouldDenyUnapprovedHost verifies that the proxy answers a plain-HTTP
// request with 403 Forbidden when the target host is not listed in the
// approved-hosts file.
func TestShouldDenyUnapprovedHost(t *testing.T) {
	// An empty approved-hosts file approves nothing.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(""), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	p := proxy.New(hostsFile)
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)

	resp, err := client.Get("http://example.com/")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403 Forbidden, got %d", resp.StatusCode)
	}
}

// TestShouldAllowApprovedHost verifies that a plain-HTTP request to a host
// listed in the approved-hosts file is forwarded and the backend's 200 OK
// reaches the client.
func TestShouldAllowApprovedHost(t *testing.T) {
	// Backend server simulating the target.
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("hello from backend"))
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL: %v", err)
	}

	// Approve only the backend's hostname (no port), one host per line.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	p := proxy.New(hostsFile)
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)

	resp, err := client.Get(backend.URL + "/test")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200 OK, got %d", resp.StatusCode)
	}
}

// TestAllowShouldAddHostToFile verifies that proxy.Allow records a new host in
// an existing, empty approved-hosts file.
func TestAllowShouldAddHostToFile(t *testing.T) {
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(""), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	if err := proxy.Allow(hostsFile, "example.com"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Check the read error separately so an I/O failure is not misreported as
	// missing content.
	data, err := os.ReadFile(hostsFile)
	if err != nil {
		t.Fatalf("reading hosts file: %v", err)
	}
	if !strings.Contains(string(data), "example.com") {
		t.Errorf("expected approved_hosts to contain example.com, got %q", string(data))
	}
}

// TestAllowShouldDeduplicateHosts verifies that calling proxy.Allow for a host
// already present in the approved-hosts file does not add a second entry.
func TestAllowShouldDeduplicateHosts(t *testing.T) {
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte("example.com\n"), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	if err := proxy.Allow(hostsFile, "example.com"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	data, err := os.ReadFile(hostsFile)
	if err != nil {
		t.Fatalf("reading hosts file: %v", err)
	}
	count := strings.Count(string(data), "example.com")
	if count != 1 {
		t.Errorf("expected 1 occurrence, got %d in %q", count, string(data))
	}
}

// TestAllowShouldCreateFileIfMissing verifies that proxy.Allow creates the
// approved-hosts file when it does not exist yet and records the host in it.
func TestAllowShouldCreateFileIfMissing(t *testing.T) {
	// The path is inside a fresh temp dir and is deliberately not created here.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")

	if err := proxy.Allow(hostsFile, "example.com"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// A read error here most likely means Allow did not create the file,
	// which is exactly the behaviour under test, so report it distinctly.
	data, err := os.ReadFile(hostsFile)
	if err != nil {
		t.Fatalf("reading hosts file (was it created?): %v", err)
	}
	if !strings.Contains(string(data), "example.com") {
		t.Errorf("expected approved_hosts to contain example.com, got %q", string(data))
	}
}

// writeTestRules writes a rules.yaml fixture into a fresh temporary directory
// and returns its path. The fixture defines two rules:
//   - ssh-private-key, matching the literal RSA private key PEM header
//   - aws-access-key, matching "AKIA" followed by 16 upper-case letters or digits
//
// A write failure fails the calling test immediately.
func writeTestRules(t *testing.T) string {
	t.Helper()
	const rules = `rules:
  - name: ssh-private-key
    pattern: "-----BEGIN RSA PRIVATE KEY-----"
  - name: aws-access-key
    pattern: "AKIA[0-9A-Z]{16}"
`
	path := filepath.Join(t.TempDir(), "rules.yaml")
	if err := os.WriteFile(path, []byte(rules), 0644); err != nil {
		t.Fatalf("writing rules file: %v", err)
	}
	return path
}

// TestShouldBlockRequestWithSSHKey verifies that, with the rules fixture
// loaded, a plain-HTTP POST to an approved host is answered with 403 Forbidden
// when its body contains an RSA private key header.
func TestShouldBlockRequestWithSSHKey(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL: %v", err)
	}
	// Approve the backend so the host allow-list is not the reason for a 403;
	// only the request body should trigger the block.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile, proxy.WithRules(rulesPath))
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)

	body := strings.NewReader("data=-----BEGIN RSA PRIVATE KEY-----\nMIIE...")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403, got %d", resp.StatusCode)
	}
}

// TestShouldAllowCleanRequest verifies that, with the rules fixture loaded, a
// plain-HTTP POST to an approved host whose body matches no rule is forwarded
// and the backend's 200 OK reaches the client.
func TestShouldAllowCleanRequest(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL: %v", err)
	}
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile, proxy.WithRules(rulesPath))
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)

	body := strings.NewReader("just normal data")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}

// TestShouldBlockHTTPSRequestWithAWSKey verifies that, with a CA configured,
// an HTTPS POST to an approved host is inspected and answered with 403 when
// its body contains an AWS access key ID.
func TestShouldBlockHTTPSRequestWithAWSKey(t *testing.T) {
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL: %v", err)
	}
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(host+"\n"), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile,
		proxy.WithRules(rulesPath),
		proxy.WithCA(caCert, caKey),
		// The httptest TLS backend uses a self-signed certificate, so the
		// proxy's upstream connection must skip verification in this test.
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	// The client trusts only the test CA. The proxy is expected to present
	// certificates signed by it, which is what lets the TLS handshake succeed.
	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatalf("parsing proxy URL: %v", err)
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	body := strings.NewReader("key=AKIAIOSFODNN7EXAMPLE")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403, got %d", resp.StatusCode)
	}
}

// TestShouldAllowCleanHTTPSRequest verifies that, with a CA configured, an
// HTTPS POST to an approved host whose body matches no rule is forwarded and
// the backend's 200 OK reaches the client.
func TestShouldAllowCleanHTTPSRequest(t *testing.T) {
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("ok"))
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL: %v", err)
	}
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(host+"\n"), 0644); err != nil {
		t.Fatalf("writing hosts file: %v", err)
	}

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile,
		proxy.WithRules(rulesPath),
		proxy.WithCA(caCert, caKey),
		// The httptest TLS backend uses a self-signed certificate, so the
		// proxy's upstream connection must skip verification in this test.
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	// The client trusts only the test CA. The proxy is expected to present
	// certificates signed by it, which is what lets the TLS handshake succeed.
	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatalf("parsing proxy URL: %v", err)
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	resp, err := client.Post(backend.URL+"/data", "text/plain", strings.NewReader("clean"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}