a73x

197849ba

fix: support HTTP/1.1 keep-alive in MITM CONNECT tunnel

a73x   2026-03-29 16:41

Loop over requests in the tunnel instead of handling only one,
so clients that reuse connections work correctly.

diff --git a/proxy/proxy.go b/proxy/proxy.go
index 109d2a7..1a6c468 100644
--- a/proxy/proxy.go
+++ b/proxy/proxy.go
@@ -174,31 +174,6 @@ func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) {
	}
	defer tlsConn.Close()

	req, err := http.ReadRequest(bufio.NewReader(tlsConn))
	if err != nil {
		log.Printf("ERROR: failed to read request from TLS conn for %s: %v", host, err)
		return
	}

	findings := p.scanRequest(req)
	if len(findings) > 0 {
		rules := make([]string, 0, len(findings))
		for _, f := range findings {
			rules = append(rules, f.Rule)
		}
		log.Printf("BLOCKED %s %s [%s]", req.Method, r.Host, strings.Join(rules, ", "))
		resp := &http.Response{
			StatusCode: http.StatusForbidden,
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(fmt.Sprintf("request blocked: contains sensitive data (%s)\n", strings.Join(rules, ", ")))),
		}
		resp.Header.Set("Content-Type", "text/plain")
		resp.Write(tlsConn)
		return
	}

	upstreamCfg := p.upstreamTLS
	if upstreamCfg == nil {
		upstreamCfg = &tls.Config{}
@@ -209,32 +184,61 @@ func (p *Proxy) handleConnectMITM(w http.ResponseWriter, r *http.Request) {
	upstreamConn, err := tls.Dial("tcp", r.Host, cfg)
	if err != nil {
		log.Printf("ERROR: failed to dial upstream %s: %v", r.Host, err)
		resp := &http.Response{
			StatusCode: http.StatusBadGateway,
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(err.Error()+"\n")),
		}
		resp.Write(tlsConn)
		return
	}
	defer upstreamConn.Close()

	req.RequestURI = ""
	if err := req.Write(upstreamConn); err != nil {
		log.Printf("ERROR: failed to forward request to %s: %v", r.Host, err)
		return
	}
	clientReader := bufio.NewReader(tlsConn)
	upstreamReader := bufio.NewReader(upstreamConn)

	upstreamResp, err := http.ReadResponse(bufio.NewReader(upstreamConn), req)
	if err != nil {
		log.Printf("ERROR: failed to read response from %s: %v", r.Host, err)
		return
	}
	defer upstreamResp.Body.Close()
	for {
		req, err := http.ReadRequest(clientReader)
		if err != nil {
			if err != io.EOF {
				log.Printf("ERROR: failed to read request from TLS conn for %s: %v", host, err)
			}
			return
		}

		findings := p.scanRequest(req)
		if len(findings) > 0 {
			rules := make([]string, 0, len(findings))
			for _, f := range findings {
				rules = append(rules, f.Rule)
			}
			log.Printf("BLOCKED %s %s [%s]", req.Method, r.Host, strings.Join(rules, ", "))
			resp := &http.Response{
				StatusCode: http.StatusForbidden,
				ProtoMajor: 1,
				ProtoMinor: 1,
				Header:     make(http.Header),
				Body:       io.NopCloser(strings.NewReader(fmt.Sprintf("request blocked: contains sensitive data (%s)\n", strings.Join(rules, ", ")))),
			}
			resp.Header.Set("Content-Type", "text/plain")
			resp.Header.Set("Connection", "close")
			resp.Write(tlsConn)
			return
		}

	upstreamResp.Write(tlsConn)
		req.RequestURI = ""
		if err := req.Write(upstreamConn); err != nil {
			log.Printf("ERROR: failed to forward request to %s: %v", r.Host, err)
			return
		}

		upstreamResp, err := http.ReadResponse(upstreamReader, req)
		if err != nil {
			log.Printf("ERROR: failed to read response from %s: %v", r.Host, err)
			return
		}

		upstreamResp.Write(tlsConn)
		upstreamResp.Body.Close()

		if req.Close || upstreamResp.Close {
			return
		}
	}
}

func (p *Proxy) getOrCreateLeaf(host string) (*tls.Certificate, error) {