// Mirror of https://github.com/go-gitea/gitea.git (synced 2025-11-03 08:02:36 +09:00).
// Original listing: 169 lines, 3.7 KiB, Go.
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package socks
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"io"
 | 
						|
	"net"
 | 
						|
	"strconv"
 | 
						|
	"time"
 | 
						|
)
 | 
						|
 | 
						|
// Deadline values that connect passes to net.Conn.SetDeadline.
var (
	// noDeadline is the zero time.Time. Passing it to SetDeadline removes
	// any deadline from a net.Conn.
	noDeadline time.Time

	// aLongTimeAgo is a moment far in the past (one second after the Unix
	// epoch). Installing it as a connection deadline makes pending and future
	// I/O on that connection fail at once. connect uses this to interrupt
	// the handshake when its context is done.
	aLongTimeAgo = time.Unix(1, 0)
)
 | 
						|
 | 
						|
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
 | 
						|
	host, port, err := splitHostPort(address)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
 | 
						|
		c.SetDeadline(deadline)
 | 
						|
		defer c.SetDeadline(noDeadline)
 | 
						|
	}
 | 
						|
	if ctx != context.Background() {
 | 
						|
		errCh := make(chan error, 1)
 | 
						|
		done := make(chan struct{})
 | 
						|
		defer func() {
 | 
						|
			close(done)
 | 
						|
			if ctxErr == nil {
 | 
						|
				ctxErr = <-errCh
 | 
						|
			}
 | 
						|
		}()
 | 
						|
		go func() {
 | 
						|
			select {
 | 
						|
			case <-ctx.Done():
 | 
						|
				c.SetDeadline(aLongTimeAgo)
 | 
						|
				errCh <- ctx.Err()
 | 
						|
			case <-done:
 | 
						|
				errCh <- nil
 | 
						|
			}
 | 
						|
		}()
 | 
						|
	}
 | 
						|
 | 
						|
	b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
 | 
						|
	b = append(b, Version5)
 | 
						|
	if len(d.AuthMethods) == 0 || d.Authenticate == nil {
 | 
						|
		b = append(b, 1, byte(AuthMethodNotRequired))
 | 
						|
	} else {
 | 
						|
		ams := d.AuthMethods
 | 
						|
		if len(ams) > 255 {
 | 
						|
			return nil, errors.New("too many authentication methods")
 | 
						|
		}
 | 
						|
		b = append(b, byte(len(ams)))
 | 
						|
		for _, am := range ams {
 | 
						|
			b = append(b, byte(am))
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if _, ctxErr = c.Write(b); ctxErr != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if b[0] != Version5 {
 | 
						|
		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
 | 
						|
	}
 | 
						|
	am := AuthMethod(b[1])
 | 
						|
	if am == AuthMethodNoAcceptableMethods {
 | 
						|
		return nil, errors.New("no acceptable authentication methods")
 | 
						|
	}
 | 
						|
	if d.Authenticate != nil {
 | 
						|
		if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
 | 
						|
			return
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	b = b[:0]
 | 
						|
	b = append(b, Version5, byte(d.cmd), 0)
 | 
						|
	if ip := net.ParseIP(host); ip != nil {
 | 
						|
		if ip4 := ip.To4(); ip4 != nil {
 | 
						|
			b = append(b, AddrTypeIPv4)
 | 
						|
			b = append(b, ip4...)
 | 
						|
		} else if ip6 := ip.To16(); ip6 != nil {
 | 
						|
			b = append(b, AddrTypeIPv6)
 | 
						|
			b = append(b, ip6...)
 | 
						|
		} else {
 | 
						|
			return nil, errors.New("unknown address type")
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		if len(host) > 255 {
 | 
						|
			return nil, errors.New("FQDN too long")
 | 
						|
		}
 | 
						|
		b = append(b, AddrTypeFQDN)
 | 
						|
		b = append(b, byte(len(host)))
 | 
						|
		b = append(b, host...)
 | 
						|
	}
 | 
						|
	b = append(b, byte(port>>8), byte(port))
 | 
						|
	if _, ctxErr = c.Write(b); ctxErr != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if b[0] != Version5 {
 | 
						|
		return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
 | 
						|
	}
 | 
						|
	if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
 | 
						|
		return nil, errors.New("unknown error " + cmdErr.String())
 | 
						|
	}
 | 
						|
	if b[2] != 0 {
 | 
						|
		return nil, errors.New("non-zero reserved field")
 | 
						|
	}
 | 
						|
	l := 2
 | 
						|
	var a Addr
 | 
						|
	switch b[3] {
 | 
						|
	case AddrTypeIPv4:
 | 
						|
		l += net.IPv4len
 | 
						|
		a.IP = make(net.IP, net.IPv4len)
 | 
						|
	case AddrTypeIPv6:
 | 
						|
		l += net.IPv6len
 | 
						|
		a.IP = make(net.IP, net.IPv6len)
 | 
						|
	case AddrTypeFQDN:
 | 
						|
		if _, err := io.ReadFull(c, b[:1]); err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
		l += int(b[0])
 | 
						|
	default:
 | 
						|
		return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
 | 
						|
	}
 | 
						|
	if cap(b) < l {
 | 
						|
		b = make([]byte, l)
 | 
						|
	} else {
 | 
						|
		b = b[:l]
 | 
						|
	}
 | 
						|
	if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	if a.IP != nil {
 | 
						|
		copy(a.IP, b)
 | 
						|
	} else {
 | 
						|
		a.Name = string(b[:len(b)-2])
 | 
						|
	}
 | 
						|
	a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
 | 
						|
	return &a, nil
 | 
						|
}
 | 
						|
 | 
						|
// splitHostPort breaks address ("host:port") into its host and a numeric
// port. The port must be a decimal integer in the range 1 through 65535.
// On failure the host is "" and the port is 0.
func splitHostPort(address string) (string, int, error) {
	h, p, err := net.SplitHostPort(address)
	if err != nil {
		return "", 0, err
	}
	n, err := strconv.Atoi(p)
	switch {
	case err != nil:
		return "", 0, err
	case n < 1 || n > 0xffff:
		// Port 0 is rejected as well as values that do not fit in 16 bits.
		return "", 0, errors.New("port number out of range " + p)
	}
	return h, n, nil
}
 |