mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-06-02 21:13:44 +08:00
fix(proxy): support HTTP CONNECT dialer
This commit is contained in:
@@ -34,7 +34,7 @@ func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||
if cfg != nil {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||
log.Errorf("failed to configure proxy dialer for %q: %v", proxyutil.Redact(cfg.ProxyURL), errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
return httpClient
|
||||
}
|
||||
// If proxy setup failed, log and fall through to context RoundTripper
|
||||
log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL)
|
||||
log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyutil.Redact(proxyURL))
|
||||
}
|
||||
|
||||
// Priority 3: Use RoundTripper from context (typically from RoundTripperFor)
|
||||
|
||||
@@ -30,7 +30,7 @@ func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
|
||||
if proxyURL != "" {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
|
||||
if errBuild != nil {
|
||||
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
|
||||
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyutil.Redact(proxyURL), errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -50,7 +53,7 @@ func Parse(raw string) (Setting, error) {
|
||||
parsedURL, errParse := url.Parse(trimmed)
|
||||
if errParse != nil {
|
||||
setting.Mode = ModeInvalid
|
||||
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
|
||||
return setting, fmt.Errorf("parse proxy URL failed")
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
setting.Mode = ModeInvalid
|
||||
@@ -134,6 +137,9 @@ func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
|
||||
case ModeDirect:
|
||||
return proxy.Direct, setting.Mode, nil
|
||||
case ModeProxy:
|
||||
if setting.URL.Scheme == "http" || setting.URL.Scheme == "https" {
|
||||
return &httpConnectDialer{proxyURL: setting.URL, dialer: proxy.Direct}, setting.Mode, nil
|
||||
}
|
||||
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
|
||||
if errDialer != nil {
|
||||
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
|
||||
@@ -143,3 +149,118 @@ func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
|
||||
return nil, setting.Mode, nil
|
||||
}
|
||||
}
|
||||
|
||||
type httpConnectDialer struct {
|
||||
proxyURL *url.URL
|
||||
dialer proxy.Dialer
|
||||
}
|
||||
|
||||
func (d *httpConnectDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
proxyConn, errDial := d.dialer.Dial(network, proxyDialAddr(d.proxyURL))
|
||||
if errDial != nil {
|
||||
return nil, fmt.Errorf("dial HTTP proxy failed: %w", errDial)
|
||||
}
|
||||
|
||||
conn := proxyConn
|
||||
if d.proxyURL.Scheme == "https" {
|
||||
tlsConn := tls.Client(conn, &tls.Config{ServerName: d.proxyURL.Hostname()})
|
||||
if errHandshake := tlsConn.Handshake(); errHandshake != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("HTTPS proxy TLS handshake failed: %w; close failed: %v", errHandshake, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("HTTPS proxy TLS handshake failed: %w", errHandshake)
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Host: addr},
|
||||
Host: addr,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
if d.proxyURL.User != nil {
|
||||
req.Header.Set("Proxy-Authorization", proxyAuthorization(d.proxyURL.User))
|
||||
}
|
||||
if errWrite := req.Write(conn); errWrite != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("write CONNECT request failed: %w; close failed: %v", errWrite, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("write CONNECT request failed: %w", errWrite)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, errRead := http.ReadResponse(reader, req)
|
||||
if errRead != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("read CONNECT response failed: %w; close failed: %v", errRead, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("read CONNECT response failed: %w", errRead)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("proxy CONNECT returned status %s; close failed: %v", resp.Status, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy CONNECT returned status %s", resp.Status)
|
||||
}
|
||||
|
||||
if reader.Buffered() > 0 {
|
||||
return &bufferedConn{Conn: conn, reader: reader}, nil
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func proxyDialAddr(proxyURL *url.URL) string {
|
||||
port := proxyURL.Port()
|
||||
if port == "" {
|
||||
port = "80"
|
||||
if proxyURL.Scheme == "https" {
|
||||
port = "443"
|
||||
}
|
||||
}
|
||||
return net.JoinHostPort(proxyURL.Hostname(), port)
|
||||
}
|
||||
|
||||
func proxyAuthorization(user *url.Userinfo) string {
|
||||
username := user.Username()
|
||||
password, _ := user.Password()
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
return "Basic " + encoded
|
||||
}
|
||||
|
||||
// Redact returns a log-safe proxy URL with credentials and path-like data removed.
|
||||
func Redact(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(trimmed)
|
||||
if errParse != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return "<invalid proxy URL>"
|
||||
}
|
||||
|
||||
redacted := &url.URL{
|
||||
Scheme: parsedURL.Scheme,
|
||||
Host: parsedURL.Host,
|
||||
}
|
||||
if parsedURL.User != nil {
|
||||
redacted.User = url.User("redacted")
|
||||
}
|
||||
return redacted.String()
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
if c.reader.Buffered() > 0 {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func mustDefaultTransport(t *testing.T) *http.Transport {
|
||||
@@ -159,3 +166,157 @@ func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) {
|
||||
t.Fatal("expected SOCKS5H transport to have custom DialContext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDialerHTTPProxyCONNECT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listener, errListen := net.Listen("tcp", "127.0.0.1:0")
|
||||
if errListen != nil {
|
||||
t.Fatalf("net.Listen returned error: %v", errListen)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := listener.Close(); errClose != nil {
|
||||
t.Errorf("listener.Close returned error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
conn, errAccept := listener.Accept()
|
||||
if errAccept != nil {
|
||||
done <- errAccept
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
if errDeadline := conn.SetDeadline(time.Now().Add(5 * time.Second)); errDeadline != nil {
|
||||
done <- errDeadline
|
||||
return
|
||||
}
|
||||
|
||||
req, errRead := http.ReadRequest(bufio.NewReader(conn))
|
||||
if errRead != nil {
|
||||
done <- fmt.Errorf("read CONNECT request failed: %w", errRead)
|
||||
return
|
||||
}
|
||||
if req.Method != http.MethodConnect {
|
||||
done <- fmt.Errorf("method = %s, want CONNECT", req.Method)
|
||||
return
|
||||
}
|
||||
if req.Host != "target.example.com:443" {
|
||||
done <- fmt.Errorf("host = %s, want target.example.com:443", req.Host)
|
||||
return
|
||||
}
|
||||
wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
|
||||
if gotAuth := req.Header.Get("Proxy-Authorization"); gotAuth != wantAuth {
|
||||
done <- fmt.Errorf("Proxy-Authorization = %q, want %q", gotAuth, wantAuth)
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(conn, "HTTP/1.1 200 Connection Established\r\n\r\nok"); errWrite != nil {
|
||||
done <- fmt.Errorf("write CONNECT response failed: %w", errWrite)
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, errReadTunnel := io.ReadFull(conn, buf)
|
||||
if errReadTunnel != nil {
|
||||
done <- fmt.Errorf("read tunneled payload failed after %d bytes: %w", n, errReadTunnel)
|
||||
return
|
||||
}
|
||||
if string(buf) != "ping" {
|
||||
done <- fmt.Errorf("tunneled payload = %q, want ping", string(buf))
|
||||
return
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
dialer, mode, errBuild := BuildDialer("http://user:pass@" + listener.Addr().String())
|
||||
if errBuild != nil {
|
||||
t.Fatalf("BuildDialer returned error: %v", errBuild)
|
||||
}
|
||||
if mode != ModeProxy {
|
||||
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
|
||||
}
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer, got nil")
|
||||
}
|
||||
|
||||
conn, errDial := dialer.Dial("tcp", "target.example.com:443")
|
||||
if errDial != nil {
|
||||
t.Fatalf("dialer.Dial returned error: %v", errDial)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Errorf("conn.Close returned error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 2)
|
||||
n, errRead := io.ReadFull(conn, buf)
|
||||
if errRead != nil {
|
||||
t.Fatalf("conn.Read returned error after %d bytes: %v", n, errRead)
|
||||
}
|
||||
if string(buf) != "ok" {
|
||||
t.Fatalf("buffered tunnel payload = %q, want ok", string(buf))
|
||||
}
|
||||
|
||||
if _, errWrite := conn.Write([]byte("ping")); errWrite != nil {
|
||||
t.Fatalf("conn.Write returned error: %v", errWrite)
|
||||
}
|
||||
|
||||
if errServer := <-done; errServer != nil {
|
||||
t.Fatalf("proxy server returned error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with credentials",
|
||||
input: "http://user:pass@proxy.example.com:8080/path?token=secret",
|
||||
want: "http://redacted@proxy.example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "without credentials",
|
||||
input: "socks5://proxy.example.com:1080",
|
||||
want: "socks5://proxy.example.com:1080",
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: "bad-value",
|
||||
want: "<invalid proxy URL>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := Redact(tt.input); got != tt.want {
|
||||
t.Fatalf("Redact() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseErrorDoesNotExposeProxyCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "http://user:secret%@proxy.example.com:8080"
|
||||
_, errParse := Parse(input)
|
||||
if errParse == nil {
|
||||
t.Fatal("expected Parse to return an error")
|
||||
}
|
||||
if strings.Contains(errParse.Error(), input) ||
|
||||
strings.Contains(errParse.Error(), "user") ||
|
||||
strings.Contains(errParse.Error(), "secret") {
|
||||
t.Fatalf("parse error exposes proxy credentials: %q", errParse.Error())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user