a73x

scanner/scanner_test.go

Ref:   Size: 4.7 KiB

package scanner_test

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/xanderle/nono/scanner"
)

// Task 3: Load Rules from YAML

// TestNewScanner_LoadsRulesFromYAML verifies that scanner.New loads every
// entry listed under the top-level "rules" key of a YAML rules file.
func TestNewScanner_LoadsRulesFromYAML(t *testing.T) {
	dir := t.TempDir()
	rulesPath := filepath.Join(dir, "rules.yaml")
	content := "rules:\n  - name: ssh-key\n    pattern: \"-----BEGIN RSA PRIVATE KEY-----\"\n  - name: aws-key\n    pattern: \"AKIA[0-9A-Z]{16}\"\n"
	// Fail fast on fixture problems so they are not misreported as a
	// scanner.New failure.
	if err := os.WriteFile(rulesPath, []byte(content), 0644); err != nil {
		t.Fatalf("writing rules file: %v", err)
	}

	s, err := scanner.New(rulesPath)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := s.RuleCount(); got != 2 {
		t.Errorf("expected 2 rules, got %d", got)
	}
}

// TestNewScanner_RejectsInvalidRegex verifies that scanner.New returns an
// error when a rule's pattern is not a valid regular expression.
func TestNewScanner_RejectsInvalidRegex(t *testing.T) {
	dir := t.TempDir()
	rulesPath := filepath.Join(dir, "rules.yaml")
	// If the fixture were silently missing, scanner.New would fail for the
	// wrong reason (file not found) and this test would pass vacuously.
	if err := os.WriteFile(rulesPath, []byte("rules:\n  - name: bad-rule\n    pattern: \"[invalid\"\n"), 0644); err != nil {
		t.Fatalf("writing rules file: %v", err)
	}

	if _, err := scanner.New(rulesPath); err == nil {
		t.Fatal("expected error for invalid regex")
	}
}

// Task 4: Scan for Sensitive Patterns

// writeRules writes content to a file named rules.yaml inside a fresh
// per-test temporary directory and returns the file's path. The calling test
// fails immediately if the file cannot be written, so callers never receive
// a path to a missing or partial fixture.
func writeRules(t *testing.T, content string) string {
	t.Helper()
	path := filepath.Join(t.TempDir(), "rules.yaml")
	if err := os.WriteFile(path, []byte(content), 0644); err != nil {
		t.Fatalf("writing rules file: %v", err)
	}
	return path
}

// TestScan_DetectsSSHPrivateKey verifies that an SSH private-key header in a
// body produces exactly one finding attributed to the ssh-private-key rule.
func TestScan_DetectsSSHPrivateKey(t *testing.T) {
	path := writeRules(t, "rules:\n  - name: ssh-private-key\n    pattern: \"-----BEGIN (OPENSSH|RSA|DSA|EC|ED25519) PRIVATE KEY-----\"\n")
	s, err := scanner.New(path)
	if err != nil {
		// Report load failures directly instead of letting them surface later
		// as a confusing assertion failure (or a panic if s is nil).
		t.Fatalf("loading rules: %v", err)
	}
	findings := s.Scan([]byte("some data\n-----BEGIN RSA PRIVATE KEY-----\nMIIE..."), "")
	if len(findings) != 1 {
		t.Fatalf("expected 1 finding, got %d", len(findings))
	}
	if findings[0].Rule != "ssh-private-key" {
		t.Errorf("expected rule 'ssh-private-key', got %q", findings[0].Rule)
	}
}

// TestScan_DetectsAWSKey verifies that an AWS access key ID embedded in a JSON
// body produces exactly one finding attributed to the aws-access-key rule.
func TestScan_DetectsAWSKey(t *testing.T) {
	path := writeRules(t, "rules:\n  - name: aws-access-key\n    pattern: \"AKIA[0-9A-Z]{16}\"\n")
	s, err := scanner.New(path)
	if err != nil {
		t.Fatalf("loading rules: %v", err)
	}
	findings := s.Scan([]byte("{\"key\": \"AKIAIOSFODNN7EXAMPLE\"}"), "")
	if len(findings) != 1 {
		t.Fatalf("expected 1 finding, got %d", len(findings))
	}
	if findings[0].Rule != "aws-access-key" {
		t.Errorf("expected rule 'aws-access-key', got %q", findings[0].Rule)
	}
}

// TestScan_ReturnsMultipleFindings verifies that a body matching two distinct
// rules yields one finding per matching rule.
func TestScan_ReturnsMultipleFindings(t *testing.T) {
	path := writeRules(t, "rules:\n  - name: ssh-private-key\n    pattern: \"-----BEGIN RSA PRIVATE KEY-----\"\n  - name: aws-access-key\n    pattern: \"AKIA[0-9A-Z]{16}\"\n")
	s, err := scanner.New(path)
	if err != nil {
		t.Fatalf("loading rules: %v", err)
	}
	body := []byte("-----BEGIN RSA PRIVATE KEY-----\nkey\nAKIAIOSFODNN7EXAMPLE")
	findings := s.Scan(body, "")
	if len(findings) != 2 {
		t.Fatalf("expected 2 findings, got %d", len(findings))
	}
}

// TestScan_ReturnsEmptyForCleanBody verifies that a body matching no rule
// produces no findings.
func TestScan_ReturnsEmptyForCleanBody(t *testing.T) {
	path := writeRules(t, "rules:\n  - name: ssh-private-key\n    pattern: \"-----BEGIN RSA PRIVATE KEY-----\"\n")
	s, err := scanner.New(path)
	if err != nil {
		// A scanner that failed to load its rules could trivially report zero
		// findings, so a load error must fail the test explicitly.
		t.Fatalf("loading rules: %v", err)
	}
	findings := s.Scan([]byte("just some normal POST data"), "")
	if len(findings) != 0 {
		t.Errorf("expected 0 findings, got %d", len(findings))
	}
}

// TestScan_TruncatesMatchSnippet checks that the Match snippet reported in a
// finding is at most 40 characters long.
func TestScan_TruncatesMatchSnippet(t *testing.T) {
	path := writeRules(t, "rules:\n  - name: ssh-private-key\n    pattern: \"-----BEGIN RSA PRIVATE KEY-----\"\n")
	s, err := scanner.New(path)
	if err != nil {
		t.Fatalf("loading rules: %v", err)
	}
	// NOTE(review): the matched text below is only 31 bytes, already under the
	// 40-char bound, so this input never forces truncation. A longer match is
	// needed to exercise the truncation path; confirm the intended limit
	// (and whether an ellipsis is appended) before tightening this test.
	findings := s.Scan([]byte("-----BEGIN RSA PRIVATE KEY-----"), "")
	if len(findings) != 1 {
		t.Fatalf("expected 1 finding, got %d", len(findings))
	}
	if n := len(findings[0].Match); n > 40 {
		t.Errorf("expected match snippet to be truncated, got %d chars", n)
	}
}

// Task 5: Default Rules File

// TestWriteDefaultRules_CreatesFile checks that WriteDefaultRules produces a
// rules file that scanner.New accepts and that it contains at least nine rules.
func TestWriteDefaultRules_CreatesFile(t *testing.T) {
	rulesFile := filepath.Join(t.TempDir(), "rules.yaml")

	if err := scanner.WriteDefaultRules(rulesFile); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	sc, loadErr := scanner.New(rulesFile)
	if loadErr != nil {
		t.Fatalf("failed to load default rules: %v", loadErr)
	}

	const minDefaultRules = 9
	if got := sc.RuleCount(); got < minDefaultRules {
		t.Errorf("expected at least 9 default rules, got %d", got)
	}
}

// TestWriteDefaultRules_DoesNotOverwrite verifies that WriteDefaultRules leaves
// an existing rules file untouched: a pre-written single-rule file must still
// load with exactly one rule afterwards.
func TestWriteDefaultRules_DoesNotOverwrite(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "rules.yaml")
	if err := os.WriteFile(path, []byte("rules:\n  - name: custom\n    pattern: \"custom\"\n"), 0644); err != nil {
		t.Fatalf("writing rules file: %v", err)
	}
	if err := scanner.WriteDefaultRules(path); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	s, err := scanner.New(path)
	if err != nil {
		t.Fatalf("loading rules: %v", err)
	}
	if got := s.RuleCount(); got != 1 {
		t.Errorf("expected 1 rule (not overwritten), got %d", got)
	}
}

// TestScan_SkipsExemptHost verifies that a rule listing exempt_hosts is not
// reported when Scan is given one of those hosts, but still fires for any
// other host.
func TestScan_SkipsExemptHost(t *testing.T) {
	path := writeRules(t, `rules:
  - name: bearer-token
    pattern: "Authorization:\\s*Bearer\\s+"
    exempt_hosts:
      - api.anthropic.com
`)
	s, err := scanner.New(path)
	if err != nil {
		// Without this, a load failure could make the exempt-host case pass
		// vacuously (zero findings) before the second assertion catches it.
		t.Fatalf("loading rules: %v", err)
	}

	body := []byte("Authorization: Bearer sk-ant-123")

	findings := s.Scan(body, "api.anthropic.com")
	if len(findings) != 0 {
		t.Errorf("expected 0 findings for exempt host, got %d", len(findings))
	}

	findings = s.Scan(body, "evil.com")
	if len(findings) != 1 {
		t.Errorf("expected 1 finding for non-exempt host, got %d", len(findings))
	}
}