proxy/proxy_test.go
Ref: Size: 7.6 KiB
package proxy_test
import (
"crypto/tls"
"crypto/x509"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
nca "github.com/xanderle/nono/ca"
"github.com/xanderle/nono/proxy"
)
// newProxyClient returns an *http.Client whose transport routes every request
// through the proxy located at proxyURL.
//
// proxyURL is parsed once, up front. A malformed URL fails the calling test
// immediately rather than surfacing later as a per-request transport error.
func newProxyClient(t *testing.T, proxyURL string) *http.Client {
	t.Helper()
	parsed, err := url.Parse(proxyURL)
	if err != nil {
		t.Fatalf("parsing proxy URL %q: %v", proxyURL, err)
	}
	return &http.Client{
		Transport: &http.Transport{
			// http.ProxyURL yields a Proxy func that always returns parsed.
			Proxy: http.ProxyURL(parsed),
		},
	}
}
// TestShouldDenyUnapprovedHost checks that a request routed through the proxy
// to a host absent from the approved-hosts file is answered with 403 Forbidden.
func TestShouldDenyUnapprovedHost(t *testing.T) {
	// An empty approved-hosts file means no host is on the allow-list.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(""), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	p := proxy.New(hostsFile)
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)
	resp, err := client.Get("http://example.com/")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403 Forbidden, got %d", resp.StatusCode)
	}
}
// TestShouldAllowApprovedHost checks that a request routed through the proxy
// to a host listed in the approved-hosts file comes back with 200 OK.
func TestShouldAllowApprovedHost(t *testing.T) {
	// Backend server simulating the target.
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("hello from backend"))
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL %q: %v", backend.URL, err)
	}

	// Approve only the backend's hostname (no port).
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	p := proxy.New(hostsFile)
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)
	resp, err := client.Get(backend.URL + "/test")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200 OK, got %d", resp.StatusCode)
	}
}
// TestAllowShouldAddHostToFile checks that proxy.Allow, given an existing
// empty approved-hosts file, leaves the file containing the new host.
func TestAllowShouldAddHostToFile(t *testing.T) {
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(""), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	if err := proxy.Allow(hostsFile, "example.com"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	data, err := os.ReadFile(hostsFile)
	if err != nil {
		t.Fatalf("reading approved hosts file: %v", err)
	}
	if !strings.Contains(string(data), "example.com") {
		t.Errorf("expected approved_hosts to contain example.com, got %q", string(data))
	}
}
// TestAllowShouldDeduplicateHosts checks that calling proxy.Allow for a host
// already present in the approved-hosts file does not add a second entry.
func TestAllowShouldDeduplicateHosts(t *testing.T) {
	// Seed the file with the host that is about to be allowed again.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte("example.com\n"), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	if err := proxy.Allow(hostsFile, "example.com"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	data, err := os.ReadFile(hostsFile)
	if err != nil {
		t.Fatalf("reading approved hosts file: %v", err)
	}
	count := strings.Count(string(data), "example.com")
	if count != 1 {
		t.Errorf("expected 1 occurrence, got %d in %q", count, string(data))
	}
}
// TestAllowShouldCreateFileIfMissing checks that proxy.Allow succeeds when the
// approved-hosts file does not exist yet and that the file afterwards exists
// and contains the allowed host.
func TestAllowShouldCreateFileIfMissing(t *testing.T) {
	// The path lives in a fresh temp dir and is deliberately not created here.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")

	if err := proxy.Allow(hostsFile, "example.com"); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// A read failure here means Allow did not create the file; report that
	// directly instead of as an empty-content mismatch.
	data, err := os.ReadFile(hostsFile)
	if err != nil {
		t.Fatalf("reading approved hosts file: %v", err)
	}
	if !strings.Contains(string(data), "example.com") {
		t.Errorf("expected approved_hosts to contain example.com, got %q", string(data))
	}
}
// writeTestRules writes a rules.yaml file into a fresh per-test temp directory
// and returns its path. The file declares two rules: "ssh-private-key"
// (matching an RSA private key header) and "aws-access-key" (matching
// AKIA followed by 16 uppercase letters or digits).
//
// The calling test fails immediately if the file cannot be written.
func writeTestRules(t *testing.T) string {
	t.Helper()
	dir := t.TempDir()
	path := filepath.Join(dir, "rules.yaml")
	// NOTE(review): as written, "pattern" is indented at the same column as
	// the "- name" list marker; confirm the rules loader parses this layout
	// as intended.
	rules := []byte("rules:\n - name: ssh-private-key\n pattern: \"-----BEGIN RSA PRIVATE KEY-----\"\n - name: aws-access-key\n pattern: \"AKIA[0-9A-Z]{16}\"\n")
	if err := os.WriteFile(path, rules, 0644); err != nil {
		t.Fatalf("writing rules file %q: %v", path, err)
	}
	return path
}
// TestShouldBlockRequestWithSSHKey checks that a plain-HTTP POST to an
// approved host is answered with 403 when its body contains an RSA private
// key header matched by the configured rules.
func TestShouldBlockRequestWithSSHKey(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL %q: %v", backend.URL, err)
	}

	// Approve the backend host so that only the content rules can cause a 403.
	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile, proxy.WithRules(rulesPath))
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)
	body := strings.NewReader("data=-----BEGIN RSA PRIVATE KEY-----\nMIIE...")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403, got %d", resp.StatusCode)
	}
}
// TestShouldAllowCleanRequest checks that a plain-HTTP POST to an approved
// host passes through with 200 when its body matches none of the configured
// rules.
func TestShouldAllowCleanRequest(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL %q: %v", backend.URL, err)
	}

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(backendURL.Hostname()+"\n"), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile, proxy.WithRules(rulesPath))
	srv := httptest.NewServer(p)
	defer srv.Close()

	client := newProxyClient(t, srv.URL)
	body := strings.NewReader("just normal data")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}
// TestShouldBlockHTTPSRequestWithAWSKey checks that an HTTPS POST to an
// approved TLS backend is answered with 403 when its body contains a string
// matching the AWS access key rule.
//
// The proxy is given a freshly created CA (proxy.WithCA) and the client trusts
// only that CA, presumably so the proxy can terminate TLS and inspect the
// request body — confirm against the proxy implementation.
func TestShouldBlockHTTPSRequestWithAWSKey(t *testing.T) {
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL %q: %v", backend.URL, err)
	}
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(host+"\n"), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile,
		proxy.WithRules(rulesPath),
		proxy.WithCA(caCert, caKey),
		// The httptest TLS backend uses a self-signed test certificate, so
		// upstream verification is disabled for this test only.
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	// The client trusts only the test CA handed to the proxy.
	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatalf("parsing proxy URL %q: %v", srv.URL, err)
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	body := strings.NewReader("key=AKIAIOSFODNN7EXAMPLE")
	resp, err := client.Post(backend.URL+"/upload", "text/plain", body)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected 403, got %d", resp.StatusCode)
	}
}
// TestShouldAllowCleanHTTPSRequest checks that an HTTPS POST to an approved
// TLS backend passes through the proxy with 200 when its body matches none of
// the configured rules.
//
// The proxy is given a freshly created CA (proxy.WithCA) and the client trusts
// only that CA, presumably so the proxy can terminate TLS and inspect the
// request body — confirm against the proxy implementation.
func TestShouldAllowCleanHTTPSRequest(t *testing.T) {
	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("ok"))
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("parsing backend URL %q: %v", backend.URL, err)
	}
	host := backendURL.Hostname()

	hostsFile := filepath.Join(t.TempDir(), "approved_hosts")
	if err := os.WriteFile(hostsFile, []byte(host+"\n"), 0644); err != nil {
		t.Fatalf("writing approved hosts file: %v", err)
	}

	caDir := t.TempDir()
	caCert, caKey, err := nca.LoadOrCreate(caDir)
	if err != nil {
		t.Fatalf("CA setup: %v", err)
	}

	rulesPath := writeTestRules(t)
	p := proxy.New(hostsFile,
		proxy.WithRules(rulesPath),
		proxy.WithCA(caCert, caKey),
		// The httptest TLS backend uses a self-signed test certificate, so
		// upstream verification is disabled for this test only.
		proxy.WithUpstreamTLS(&tls.Config{InsecureSkipVerify: true}),
	)
	srv := httptest.NewServer(p)
	defer srv.Close()

	// The client trusts only the test CA handed to the proxy.
	caPool := x509.NewCertPool()
	caPool.AddCert(caCert)

	proxyURL, err := url.Parse(srv.URL)
	if err != nil {
		t.Fatalf("parsing proxy URL %q: %v", srv.URL, err)
	}
	client := &http.Client{
		Transport: &http.Transport{
			Proxy: http.ProxyURL(proxyURL),
			TLSClientConfig: &tls.Config{
				RootCAs: caPool,
			},
		},
	}

	resp, err := client.Post(backend.URL+"/data", "text/plain", strings.NewReader("clean"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected 200, got %d", resp.StatusCode)
	}
}