mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 13:07:46 +09:00 
			
		
		
		
	Multi network integration tests (#2464)
This commit is contained in:
		| @@ -70,8 +70,9 @@ jobs: | ||||
|           - TestAutoApprovedSubRoute2068 | ||||
|           - TestSubnetRouteACL | ||||
|           - TestEnablingExitRoutes | ||||
|           - TestSubnetRouterMultiNetwork | ||||
|           - TestSubnetRouterMultiNetworkExitNode | ||||
|           - TestHeadscale | ||||
|           - TestCreateTailscale | ||||
|           - TestTailscaleNodesJoiningHeadcale | ||||
|           - TestSSHOneUserToAll | ||||
|           - TestSSHMultipleUsersAllToAll | ||||
|   | ||||
							
								
								
									
										3
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							| @@ -70,8 +70,9 @@ jobs: | ||||
|           - TestAutoApprovedSubRoute2068 | ||||
|           - TestSubnetRouteACL | ||||
|           - TestEnablingExitRoutes | ||||
|           - TestSubnetRouterMultiNetwork | ||||
|           - TestSubnetRouterMultiNetworkExitNode | ||||
|           - TestHeadscale | ||||
|           - TestCreateTailscale | ||||
|           - TestTailscaleNodesJoiningHeadcale | ||||
|           - TestSSHOneUserToAll | ||||
|           - TestSSHMultipleUsersAllToAll | ||||
|   | ||||
| @@ -165,9 +165,13 @@ func Test_fullMapResponse(t *testing.T) { | ||||
| 		), | ||||
| 		Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, | ||||
| 		AllowedIPs: []netip.Prefix{ | ||||
| 			netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 			tsaddr.AllIPv4(), | ||||
| 			netip.MustParsePrefix("192.168.0.0/24"), | ||||
| 			netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 			tsaddr.AllIPv6(), | ||||
| 		}, | ||||
| 		PrimaryRoutes: []netip.Prefix{ | ||||
| 			netip.MustParsePrefix("192.168.0.0/24"), | ||||
| 		}, | ||||
| 		HomeDERP:         0, | ||||
| 		LegacyDERPString: "127.3.3.40:0", | ||||
|   | ||||
| @@ -2,13 +2,13 @@ package mapper | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/juanfont/headscale/hscontrol/policy" | ||||
| 	"github.com/juanfont/headscale/hscontrol/routes" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/samber/lo" | ||||
| 	"tailscale.com/net/tsaddr" | ||||
| 	"tailscale.com/tailcfg" | ||||
| ) | ||||
|  | ||||
| @@ -49,14 +49,6 @@ func tailNode( | ||||
| ) (*tailcfg.Node, error) { | ||||
| 	addrs := node.Prefixes() | ||||
|  | ||||
| 	allowedIPs := append( | ||||
| 		[]netip.Prefix{}, | ||||
| 		addrs...) // we append the node own IP, as it is required by the clients | ||||
|  | ||||
| 	for _, route := range node.SubnetRoutes() { | ||||
| 		allowedIPs = append(allowedIPs, netip.Prefix(route)) | ||||
| 	} | ||||
|  | ||||
| 	var derp int | ||||
|  | ||||
| 	// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077 | ||||
| @@ -89,6 +81,10 @@ func tailNode( | ||||
| 	} | ||||
| 	tags = lo.Uniq(append(tags, node.ForcedTags...)) | ||||
|  | ||||
| 	allowed := append(node.Prefixes(), primary.PrimaryRoutes(node.ID)...) | ||||
| 	allowed = append(allowed, node.ExitRoutes()...) | ||||
| 	tsaddr.SortPrefixes(allowed) | ||||
|  | ||||
| 	tNode := tailcfg.Node{ | ||||
| 		ID:       tailcfg.NodeID(node.ID), // this is the actual ID | ||||
| 		StableID: node.ID.StableID(), | ||||
| @@ -104,7 +100,7 @@ func tailNode( | ||||
| 		DiscoKey:         node.DiscoKey, | ||||
| 		Addresses:        addrs, | ||||
| 		PrimaryRoutes:    primary.PrimaryRoutes(node.ID), | ||||
| 		AllowedIPs:       allowedIPs, | ||||
| 		AllowedIPs:       allowed, | ||||
| 		Endpoints:        node.Endpoints, | ||||
| 		HomeDERP:         derp, | ||||
| 		LegacyDERPString: legacyDERP, | ||||
|   | ||||
| @@ -67,8 +67,6 @@ func TestTailNode(t *testing.T) { | ||||
| 			want: &tailcfg.Node{ | ||||
| 				Name:              "empty", | ||||
| 				StableID:          "0", | ||||
| 				Addresses:         []netip.Prefix{}, | ||||
| 				AllowedIPs:        []netip.Prefix{}, | ||||
| 				HomeDERP:          0, | ||||
| 				LegacyDERPString:  "127.3.3.40:0", | ||||
| 				Hostinfo:          hiview(tailcfg.Hostinfo{}), | ||||
| @@ -139,9 +137,13 @@ func TestTailNode(t *testing.T) { | ||||
| 				), | ||||
| 				Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, | ||||
| 				AllowedIPs: []netip.Prefix{ | ||||
| 					netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 					tsaddr.AllIPv4(), | ||||
| 					netip.MustParsePrefix("192.168.0.0/24"), | ||||
| 					netip.MustParsePrefix("100.64.0.1/32"), | ||||
| 					tsaddr.AllIPv6(), | ||||
| 				}, | ||||
| 				PrimaryRoutes: []netip.Prefix{ | ||||
| 					netip.MustParsePrefix("192.168.0.0/24"), | ||||
| 				}, | ||||
| 				HomeDERP:         0, | ||||
| 				LegacyDERPString: "127.3.3.40:0", | ||||
| @@ -156,10 +158,6 @@ func TestTailNode(t *testing.T) { | ||||
|  | ||||
| 				Tags: []string{}, | ||||
|  | ||||
| 				PrimaryRoutes: []netip.Prefix{ | ||||
| 					netip.MustParsePrefix("192.168.0.0/24"), | ||||
| 				}, | ||||
|  | ||||
| 				LastSeen:          &lastSeen, | ||||
| 				MachineAuthorized: true, | ||||
|  | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	xmaps "golang.org/x/exp/maps" | ||||
| 	"tailscale.com/net/tsaddr" | ||||
| 	"tailscale.com/util/set" | ||||
| ) | ||||
|  | ||||
| @@ -74,18 +75,12 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { | ||||
| 	// If the current primary is not available, select a new one. | ||||
| 	for prefix, nodes := range allPrimaries { | ||||
| 		if node, ok := pr.primaries[prefix]; ok { | ||||
| 			if len(nodes) < 2 { | ||||
| 				delete(pr.primaries, prefix) | ||||
| 				changed = true | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			// If the current primary is still available, continue. | ||||
| 			if slices.Contains(nodes, node) { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
| 		if len(nodes) >= 2 { | ||||
| 		if len(nodes) >= 1 { | ||||
| 			pr.primaries[prefix] = nodes[0] | ||||
| 			changed = true | ||||
| 		} | ||||
| @@ -107,12 +102,16 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { | ||||
| 	return changed | ||||
| } | ||||
|  | ||||
| func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bool { | ||||
| // SetRoutes sets the routes for a given Node ID and recalculates the primary routes | ||||
| // of the headscale. | ||||
| // It returns true if there was a change in primary routes. | ||||
| // All exit routes are ignored as they are not used in primary route context. | ||||
| func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) bool { | ||||
| 	pr.mu.Lock() | ||||
| 	defer pr.mu.Unlock() | ||||
|  | ||||
| 	// If no routes are being set, remove the node from the routes map. | ||||
| 	if len(prefix) == 0 { | ||||
| 	if len(prefixes) == 0 { | ||||
| 		if _, ok := pr.routes[node]; ok { | ||||
| 			delete(pr.routes, node) | ||||
| 			return pr.updatePrimaryLocked() | ||||
| @@ -121,12 +120,17 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefix ...netip.Prefix) bo | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if _, ok := pr.routes[node]; !ok { | ||||
| 		pr.routes[node] = make(set.Set[netip.Prefix], len(prefix)) | ||||
| 	rs := make(set.Set[netip.Prefix], len(prefixes)) | ||||
| 	for _, prefix := range prefixes { | ||||
| 		if !tsaddr.IsExitRoute(prefix) { | ||||
| 			rs.Add(prefix) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for _, p := range prefix { | ||||
| 		pr.routes[node].Add(p) | ||||
| 	if rs.Len() != 0 { | ||||
| 		pr.routes[node] = rs | ||||
| 	} else { | ||||
| 		delete(pr.routes, node) | ||||
| 	} | ||||
|  | ||||
| 	return pr.updatePrimaryLocked() | ||||
| @@ -153,6 +157,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tsaddr.SortPrefixes(routes) | ||||
| 	return routes | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -6,8 +6,10 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/google/go-cmp/cmp/cmpopts" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"tailscale.com/util/set" | ||||
| ) | ||||
|  | ||||
| // mp is a helper function that wraps netip.MustParsePrefix. | ||||
| @@ -19,18 +21,32 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name              string | ||||
| 		operations        func(pr *PrimaryRoutes) bool | ||||
| 		nodeID         types.NodeID | ||||
| 		expectedRoutes []netip.Prefix | ||||
| 		expectedRoutes    map[types.NodeID]set.Set[netip.Prefix] | ||||
| 		expectedPrimaries map[netip.Prefix]types.NodeID | ||||
| 		expectedIsPrimary map[types.NodeID]bool | ||||
| 		expectedChange    bool | ||||
|  | ||||
| 		// primaries is a map of prefixes to the node that is the primary for that prefix. | ||||
| 		primaries map[netip.Prefix]types.NodeID | ||||
| 		isPrimary map[types.NodeID]bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "single-node-registers-single-route", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				return pr.SetRoutes(1, mp("192.168.1.0/24")) | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "multiple-nodes-register-different-routes", | ||||
| @@ -38,19 +54,45 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24")) | ||||
| 				return pr.SetRoutes(2, mp("192.168.2.0/24")) | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.2.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 1, | ||||
| 				mp("192.168.2.0/24"): 2, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 				2: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "multiple-nodes-register-overlapping-routes", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24"))        // false | ||||
| 				return pr.SetRoutes(2, mp("192.168.1.0/24")) // true | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24"))        // true | ||||
| 				return pr.SetRoutes(2, mp("192.168.1.0/24")) // false | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, | ||||
| 			expectedChange: true, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 			}, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "node-deregisters-a-route", | ||||
| @@ -58,9 +100,10 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24")) | ||||
| 				return pr.SetRoutes(1) // Deregister by setting no routes | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes:    nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedPrimaries: nil, | ||||
| 			expectedIsPrimary: nil, | ||||
| 			expectedChange:    true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "node-deregisters-one-of-multiple-routes", | ||||
| @@ -68,9 +111,18 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24"), mp("192.168.2.0/24")) | ||||
| 				return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.2.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.2.0/24"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "node-registers-and-deregisters-routes-in-sequence", | ||||
| @@ -80,18 +132,23 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(1) // Deregister by setting no routes | ||||
| 				return pr.SetRoutes(1, mp("192.168.3.0/24")) | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.3.0/24"): {}, | ||||
| 				}, | ||||
| 		{ | ||||
| 			name: "no-change-in-primary-routes", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				return pr.SetRoutes(1, mp("192.168.1.0/24")) | ||||
| 				2: { | ||||
| 					mp("192.168.2.0/24"): {}, | ||||
| 				}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.2.0/24"): 2, | ||||
| 				mp("192.168.3.0/24"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 				2: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "multiple-nodes-register-same-route", | ||||
| @@ -100,22 +157,25 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(2, mp("192.168.1.0/24"))        // true | ||||
| 				return pr.SetRoutes(3, mp("192.168.1.0/24")) // false | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 			}, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "register-multiple-routes-shift-primary-check-old-primary", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24")) // false | ||||
| 				pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary | ||||
| 				pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary | ||||
| 				return pr.SetRoutes(1)                // true, 2 primary | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "register-multiple-routes-shift-primary-check-primary", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| @@ -124,20 +184,20 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary | ||||
| 				return pr.SetRoutes(1)                // true, 2 primary | ||||
| 			}, | ||||
| 			nodeID:         2, | ||||
| 			expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, | ||||
| 			expectedChange: true, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 		{ | ||||
| 			name: "register-multiple-routes-shift-primary-check-non-primary", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24")) // false | ||||
| 				pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary | ||||
| 				pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary | ||||
| 				return pr.SetRoutes(1)                // true, 2 primary | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 2, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				2: true, | ||||
| 			}, | ||||
| 			nodeID:         3, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -150,8 +210,17 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
|  | ||||
| 				return pr.SetRoutes(2) // true, no primary | ||||
| 			}, | ||||
| 			nodeID:         2, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 3, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				3: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -165,9 +234,7 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
|  | ||||
| 				return pr.SetRoutes(3) // false, no primary | ||||
| 			}, | ||||
| 			nodeID:         2, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "primary-route-map-is-cleared-up", | ||||
| @@ -179,8 +246,17 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
|  | ||||
| 				return pr.SetRoutes(2) // true, no primary | ||||
| 			}, | ||||
| 			nodeID:         2, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 3, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				3: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -193,8 +269,23 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
|  | ||||
| 				return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary | ||||
| 			}, | ||||
| 			nodeID:         2, | ||||
| 			expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 2, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				2: true, | ||||
| 			}, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -207,8 +298,23 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
|  | ||||
| 				return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 2, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				2: true, | ||||
| 			}, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -218,15 +324,30 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary | ||||
| 				pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary | ||||
| 				pr.SetRoutes(1)                       // true, 2 primary | ||||
| 				pr.SetRoutes(2)                       // true, no primary | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 1 primary | ||||
| 				pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary | ||||
| 				pr.SetRoutes(1)                       // true, 2 primary | ||||
| 				pr.SetRoutes(2)                       // true, 3 primary | ||||
| 				pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 3 primary | ||||
| 				pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 3 primary | ||||
| 				pr.SetRoutes(1)                       // true, 3 primary | ||||
|  | ||||
| 				return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary | ||||
| 				return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 3 primary | ||||
| 			}, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				3: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 3, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				3: true, | ||||
| 			}, | ||||
| 			nodeID:         2, | ||||
| 			expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -235,16 +356,27 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 				pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24")) | ||||
| 				return pr.SetRoutes(2, mp("192.168.1.0/24")) | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: []netip.Prefix{mp("192.168.1.0/24")}, | ||||
| 			expectedChange: true, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 				2: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 			}, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "deregister-non-existent-route", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				return pr.SetRoutes(1) // Deregister by setting no routes | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| @@ -253,17 +385,27 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				return pr.SetRoutes(1) | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "deregister-empty-prefix-list", | ||||
| 			name: "exit-nodes", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				return pr.SetRoutes(1) | ||||
| 				pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) | ||||
| 				pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) | ||||
| 				return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) | ||||
| 			}, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("10.0.0.0/16"): {}, | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("10.0.0.0/16"): 1, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 		}, | ||||
| 		{ | ||||
| @@ -284,19 +426,23 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
|  | ||||
| 				return change1 || change2 | ||||
| 			}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ | ||||
| 				1: { | ||||
| 					mp("192.168.1.0/24"): {}, | ||||
| 				}, | ||||
| 		{ | ||||
| 			name: "no-routes-registered", | ||||
| 			operations: func(pr *PrimaryRoutes) bool { | ||||
| 				// No operations | ||||
| 				return false | ||||
| 				2: { | ||||
| 					mp("192.168.2.0/24"): {}, | ||||
| 				}, | ||||
| 			nodeID:         1, | ||||
| 			expectedRoutes: nil, | ||||
| 			expectedChange: false, | ||||
| 			}, | ||||
| 			expectedPrimaries: map[netip.Prefix]types.NodeID{ | ||||
| 				mp("192.168.1.0/24"): 1, | ||||
| 				mp("192.168.2.0/24"): 2, | ||||
| 			}, | ||||
| 			expectedIsPrimary: map[types.NodeID]bool{ | ||||
| 				1: true, | ||||
| 				2: true, | ||||
| 			}, | ||||
| 			expectedChange: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -307,9 +453,15 @@ func TestPrimaryRoutes(t *testing.T) { | ||||
| 			if change != tt.expectedChange { | ||||
| 				t.Errorf("change = %v, want %v", change, tt.expectedChange) | ||||
| 			} | ||||
| 			routes := pr.PrimaryRoutes(tt.nodeID) | ||||
| 			if diff := cmp.Diff(tt.expectedRoutes, routes, util.Comparers...); diff != "" { | ||||
| 				t.Errorf("PrimaryRoutes() mismatch (-want +got):\n%s", diff) | ||||
| 			comps := append(util.Comparers, cmpopts.EquateEmpty()) | ||||
| 			if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { | ||||
| 				t.Errorf("routes mismatch (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 			if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { | ||||
| 				t.Errorf("primaries mismatch (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 			if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { | ||||
| 				t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
|   | ||||
| @@ -14,6 +14,7 @@ import ( | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"go4.org/netipx" | ||||
| 	"google.golang.org/protobuf/types/known/timestamppb" | ||||
| 	"tailscale.com/net/tsaddr" | ||||
| 	"tailscale.com/tailcfg" | ||||
| 	"tailscale.com/types/key" | ||||
| ) | ||||
| @@ -213,7 +214,7 @@ func (node *Node) RequestTags() []string { | ||||
| } | ||||
|  | ||||
| func (node *Node) Prefixes() []netip.Prefix { | ||||
| 	addrs := []netip.Prefix{} | ||||
| 	var addrs []netip.Prefix | ||||
| 	for _, nodeAddress := range node.IPs() { | ||||
| 		ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) | ||||
| 		addrs = append(addrs, ip) | ||||
| @@ -222,6 +223,19 @@ func (node *Node) Prefixes() []netip.Prefix { | ||||
| 	return addrs | ||||
| } | ||||
|  | ||||
| // ExitRoutes returns a list of both exit routes if the | ||||
| // node has any exit routes enabled. | ||||
| // If none are enabled, it will return nil. | ||||
| func (node *Node) ExitRoutes() []netip.Prefix { | ||||
| 	for _, route := range node.SubnetRoutes() { | ||||
| 		if tsaddr.IsExitRoute(route) { | ||||
| 			return tsaddr.ExitRoutes() | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (node *Node) IPsAsString() []string { | ||||
| 	var ret []string | ||||
|  | ||||
|   | ||||
| @@ -57,6 +57,15 @@ func GenerateRandomStringDNSSafe(size int) (string, error) { | ||||
| 	return str[:size], nil | ||||
| } | ||||
|  | ||||
| func MustGenerateRandomStringDNSSafe(size int) string { | ||||
| 	hash, err := GenerateRandomStringDNSSafe(size) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|  | ||||
| 	return hash | ||||
| } | ||||
|  | ||||
| func TailNodesToString(nodes []*tailcfg.Node) string { | ||||
| 	temp := make([]string, len(nodes)) | ||||
|  | ||||
|   | ||||
| @@ -3,8 +3,12 @@ package util | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"tailscale.com/util/cmpver" | ||||
| ) | ||||
| @@ -46,3 +50,126 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { | ||||
|  | ||||
| 	return loginURL, nil | ||||
| } | ||||
|  | ||||
| type TraceroutePath struct { | ||||
|   // Hop is the current jump in the total traceroute. | ||||
|   Hop int | ||||
|  | ||||
|   // Hostname is the resolved hostname or IP address identifying the jump | ||||
|   Hostname string | ||||
|  | ||||
|   // IP is the IP address of the jump | ||||
|   IP netip.Addr | ||||
|  | ||||
|   // Latencies is a list of the latencies for this jump | ||||
|   Latencies []time.Duration | ||||
| } | ||||
|  | ||||
| type Traceroute struct { | ||||
|   // Hostname is the resolved hostname or IP address identifying the target | ||||
|   Hostname string | ||||
|  | ||||
|   // IP is the IP address of the target | ||||
|   IP netip.Addr | ||||
|  | ||||
|   // Route is the path taken to reach the target if successful. The list is ordered by the path taken. | ||||
|   Route []TraceroutePath | ||||
|  | ||||
|   // Success indicates if the traceroute was successful. | ||||
|   Success bool | ||||
|  | ||||
|   // Err contains an error if  the traceroute was not successful. | ||||
|   Err error | ||||
| } | ||||
|  | ||||
| // ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct | ||||
| func ParseTraceroute(output string) (Traceroute, error) { | ||||
| 	lines := strings.Split(strings.TrimSpace(output), "\n") | ||||
| 	if len(lines) < 1 { | ||||
| 		return Traceroute{}, errors.New("empty traceroute output") | ||||
| 	} | ||||
|  | ||||
| 	// Parse the header line | ||||
| 	headerRegex := regexp.MustCompile(`traceroute to ([^ ]+) \(([^)]+)\)`) | ||||
| 	headerMatches := headerRegex.FindStringSubmatch(lines[0]) | ||||
| 	if len(headerMatches) != 3 { | ||||
| 		return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) | ||||
| 	} | ||||
|  | ||||
| 	hostname := headerMatches[1] | ||||
| 	ipStr := headerMatches[2] | ||||
| 	ip, err := netip.ParseAddr(ipStr) | ||||
| 	if err != nil { | ||||
| 		return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) | ||||
| 	} | ||||
|  | ||||
| 	result := Traceroute{ | ||||
| 		Hostname: hostname, | ||||
| 		IP:       ip, | ||||
| 		Route:    []TraceroutePath{}, | ||||
| 		Success:  false, | ||||
| 	} | ||||
|  | ||||
| 	// Parse each hop line | ||||
| 	hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(?:([^ ]+) \(([^)]+)\)|(\*))(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?(?:\s+(\d+\.\d+) ms)?`) | ||||
|  | ||||
| 	for i := 1; i < len(lines); i++ { | ||||
| 		matches := hopRegex.FindStringSubmatch(lines[i]) | ||||
| 		if len(matches) == 0 { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		hop, err := strconv.Atoi(matches[1]) | ||||
| 		if err != nil { | ||||
| 			return Traceroute{}, fmt.Errorf("parsing hop number: %w", err) | ||||
| 		} | ||||
|  | ||||
| 		var hopHostname string | ||||
| 		var hopIP netip.Addr | ||||
| 		var latencies []time.Duration | ||||
|  | ||||
| 		// Handle hostname and IP | ||||
| 		if matches[2] != "" && matches[3] != "" { | ||||
| 			hopHostname = matches[2] | ||||
| 			hopIP, err = netip.ParseAddr(matches[3]) | ||||
| 			if err != nil { | ||||
| 				return Traceroute{}, fmt.Errorf("parsing hop IP address %s: %w", matches[3], err) | ||||
| 			} | ||||
| 		} else if matches[4] == "*" { | ||||
| 			hopHostname = "*" | ||||
| 			// No IP for timeouts | ||||
| 		} | ||||
|  | ||||
| 		// Parse latencies | ||||
| 		for j := 5; j <= 7; j++ { | ||||
| 			if matches[j] != "" { | ||||
| 				ms, err := strconv.ParseFloat(matches[j], 64) | ||||
| 				if err != nil { | ||||
| 					return Traceroute{}, fmt.Errorf("parsing latency: %w", err) | ||||
| 				} | ||||
| 				latencies = append(latencies, time.Duration(ms*float64(time.Millisecond))) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		path := TraceroutePath{ | ||||
| 			Hop:       hop, | ||||
| 			Hostname:  hopHostname, | ||||
| 			IP:        hopIP, | ||||
| 			Latencies: latencies, | ||||
| 		} | ||||
|  | ||||
| 		result.Route = append(result.Route, path) | ||||
|  | ||||
| 		// Check if we've reached the target | ||||
| 		if hopIP == ip { | ||||
| 			result.Success = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// If we didn't reach the target, it's unsuccessful | ||||
| 	if !result.Success { | ||||
| 		result.Err = errors.New("traceroute did not reach target") | ||||
| 	} | ||||
|  | ||||
| 	return result, nil | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,13 @@ | ||||
| package util | ||||
|  | ||||
| import "testing" | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net/netip" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| ) | ||||
|  | ||||
| func TestTailscaleVersionNewerOrEqual(t *testing.T) { | ||||
| 	type args struct { | ||||
| @@ -178,3 +185,186 @@ Success.`, | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestParseTraceroute(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		input   string | ||||
| 		want    Traceroute | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "simple successful traceroute", | ||||
| 			input: `traceroute to 172.24.0.3 (172.24.0.3), 30 hops max, 46 byte packets | ||||
|  1  ts-head-hk0urr.headscale.net (100.64.0.1)  1.135 ms  0.922 ms  0.619 ms | ||||
|  2  172.24.0.3 (172.24.0.3)  0.593 ms  0.549 ms  0.522 ms`, | ||||
| 			want: Traceroute{ | ||||
| 				Hostname: "172.24.0.3", | ||||
| 				IP:       netip.MustParseAddr("172.24.0.3"), | ||||
| 				Route: []TraceroutePath{ | ||||
| 					{ | ||||
| 						Hop:      1, | ||||
| 						Hostname: "ts-head-hk0urr.headscale.net", | ||||
| 						IP:       netip.MustParseAddr("100.64.0.1"), | ||||
| 						Latencies: []time.Duration{ | ||||
| 							1135 * time.Microsecond, | ||||
| 							922 * time.Microsecond, | ||||
| 							619 * time.Microsecond, | ||||
| 						}, | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      2, | ||||
| 						Hostname: "172.24.0.3", | ||||
| 						IP:       netip.MustParseAddr("172.24.0.3"), | ||||
| 						Latencies: []time.Duration{ | ||||
| 							593 * time.Microsecond, | ||||
| 							549 * time.Microsecond, | ||||
| 							522 * time.Microsecond, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Success: true, | ||||
| 				Err:     nil, | ||||
| 			}, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "traceroute with timeouts", | ||||
| 			input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets | ||||
|  1  router.local (192.168.1.1)  1.234 ms  1.123 ms  1.121 ms | ||||
|  2  * * * | ||||
|  3  isp-gateway.net (10.0.0.1)  15.678 ms  14.789 ms  15.432 ms | ||||
|  4  8.8.8.8 (8.8.8.8)  20.123 ms  19.876 ms  20.345 ms`, | ||||
| 			want: Traceroute{ | ||||
| 				Hostname: "8.8.8.8", | ||||
| 				IP:       netip.MustParseAddr("8.8.8.8"), | ||||
| 				Route: []TraceroutePath{ | ||||
| 					{ | ||||
| 						Hop:      1, | ||||
| 						Hostname: "router.local", | ||||
| 						IP:       netip.MustParseAddr("192.168.1.1"), | ||||
| 						Latencies: []time.Duration{ | ||||
| 							1234 * time.Microsecond, | ||||
| 							1123 * time.Microsecond, | ||||
| 							1121 * time.Microsecond, | ||||
| 						}, | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      2, | ||||
| 						Hostname: "*", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      3, | ||||
| 						Hostname: "isp-gateway.net", | ||||
| 						IP:       netip.MustParseAddr("10.0.0.1"), | ||||
| 						Latencies: []time.Duration{ | ||||
| 							15678 * time.Microsecond, | ||||
| 							14789 * time.Microsecond, | ||||
| 							15432 * time.Microsecond, | ||||
| 						}, | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      4, | ||||
| 						Hostname: "8.8.8.8", | ||||
| 						IP:       netip.MustParseAddr("8.8.8.8"), | ||||
| 						Latencies: []time.Duration{ | ||||
| 							20123 * time.Microsecond, | ||||
| 							19876 * time.Microsecond, | ||||
| 							20345 * time.Microsecond, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				Success: true, | ||||
| 				Err:     nil, | ||||
| 			}, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "unsuccessful traceroute", | ||||
| 			input: `traceroute to 10.0.0.99 (10.0.0.99), 5 hops max, 60 byte packets | ||||
|  1  router.local (192.168.1.1)  1.234 ms  1.123 ms  1.121 ms | ||||
|  2  * * * | ||||
|  3  * * * | ||||
|  4  * * * | ||||
|  5  * * *`, | ||||
| 			want: Traceroute{ | ||||
| 				Hostname: "10.0.0.99", | ||||
| 				IP:       netip.MustParseAddr("10.0.0.99"), | ||||
| 				Route: []TraceroutePath{ | ||||
| 					{ | ||||
| 						Hop:      1, | ||||
| 						Hostname: "router.local", | ||||
| 						IP:       netip.MustParseAddr("192.168.1.1"), | ||||
| 						Latencies: []time.Duration{ | ||||
| 							1234 * time.Microsecond, | ||||
| 							1123 * time.Microsecond, | ||||
| 							1121 * time.Microsecond, | ||||
| 						}, | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      2, | ||||
| 						Hostname: "*", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      3, | ||||
| 						Hostname: "*", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      4, | ||||
| 						Hostname: "*", | ||||
| 					}, | ||||
| 					{ | ||||
| 						Hop:      5, | ||||
| 						Hostname: "*", | ||||
| 					}, | ||||
| 				}, | ||||
| 				Success: false, | ||||
| 				Err:     errors.New("traceroute did not reach target"), | ||||
| 			}, | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "empty input", | ||||
| 			input:   "", | ||||
| 			want:    Traceroute{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "invalid header", | ||||
| 			input:   "not a valid traceroute output", | ||||
| 			want:    Traceroute{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := ParseTraceroute(tt.input) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("ParseTraceroute() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			if tt.wantErr { | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// Special handling for error field since it can't be directly compared with cmp.Diff | ||||
| 			gotErr := got.Err | ||||
| 			wantErr := tt.want.Err | ||||
| 			got.Err = nil | ||||
| 			tt.want.Err = nil | ||||
|  | ||||
| 			if diff := cmp.Diff(tt.want, got, IPComparer); diff != "" { | ||||
| 				t.Errorf("ParseTraceroute() mismatch (-want +got):\n%s", diff) | ||||
| 			} | ||||
|  | ||||
| 			// Now check error field separately | ||||
| 			if (gotErr == nil) != (wantErr == nil) { | ||||
| 				t.Errorf("Error field: got %v, want %v", gotErr, wantErr) | ||||
| 			} else if gotErr != nil && wantErr != nil && gotErr.Error() != wantErr.Error() { | ||||
| 				t.Errorf("Error message: got %q, want %q", gotErr.Error(), wantErr.Error()) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -54,15 +54,16 @@ func aclScenario( | ||||
| 	clientsPerUser int, | ||||
| ) *Scenario { | ||||
| 	t.Helper() | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": clientsPerUser, | ||||
| 		"user2": clientsPerUser, | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: clientsPerUser, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{ | ||||
| 			// Alpine containers dont have ip6tables set up, which causes | ||||
| 			// tailscaled to stop configuring the wgengine, causing it | ||||
| @@ -96,22 +97,24 @@ func aclScenario( | ||||
| func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
|  | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 2, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	// NOTE: All want cases currently checks the | ||||
| 	// total count of expected peers, this would | ||||
| 	// typically be the client count of the users | ||||
| 	// they can access minus one (them self). | ||||
| 	tests := map[string]struct { | ||||
| 		users  map[string]int | ||||
| 		users  ScenarioSpec | ||||
| 		policy policyv1.ACLPolicy | ||||
| 		want   map[string]int | ||||
| 	}{ | ||||
| 		// Test that when we have no ACL, each client netmap has | ||||
| 		// the amount of peers of the total amount of clients | ||||
| 		"base-acls": { | ||||
| 			users: map[string]int{ | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			users: spec, | ||||
| 			policy: policyv1.ACLPolicy{ | ||||
| 				ACLs: []policyv1.ACL{ | ||||
| 					{ | ||||
| @@ -129,10 +132,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 		// each other, each node has only the number of pairs from | ||||
| 		// their own user. | ||||
| 		"two-isolated-users": { | ||||
| 			users: map[string]int{ | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			users: spec, | ||||
| 			policy: policyv1.ACLPolicy{ | ||||
| 				ACLs: []policyv1.ACL{ | ||||
| 					{ | ||||
| @@ -155,10 +155,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 		// are restricted to a single port, nodes are still present | ||||
| 		// in the netmap. | ||||
| 		"two-restricted-present-in-netmap": { | ||||
| 			users: map[string]int{ | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			users: spec, | ||||
| 			policy: policyv1.ACLPolicy{ | ||||
| 				ACLs: []policyv1.ACL{ | ||||
| 					{ | ||||
| @@ -192,10 +189,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 		// of peers. This will still result in all the peers as we | ||||
| 		// need them present on the other side for the "return path". | ||||
| 		"two-ns-one-isolated": { | ||||
| 			users: map[string]int{ | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			users: spec, | ||||
| 			policy: policyv1.ACLPolicy{ | ||||
| 				ACLs: []policyv1.ACL{ | ||||
| 					{ | ||||
| @@ -220,10 +214,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"very-large-destination-prefix-1372": { | ||||
| 			users: map[string]int{ | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			users: spec, | ||||
| 			policy: policyv1.ACLPolicy{ | ||||
| 				ACLs: []policyv1.ACL{ | ||||
| 					{ | ||||
| @@ -248,10 +239,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
| 			}, | ||||
| 		}, | ||||
| 		"ipv6-acls-1470": { | ||||
| 			users: map[string]int{ | ||||
| 				"user1": 2, | ||||
| 				"user2": 2, | ||||
| 			}, | ||||
| 			users: spec, | ||||
| 			policy: policyv1.ACLPolicy{ | ||||
| 				ACLs: []policyv1.ACL{ | ||||
| 					{ | ||||
| @@ -269,12 +257,11 @@ func TestACLHostsInNetMapTable(t *testing.T) { | ||||
|  | ||||
| 	for name, testCase := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			caseSpec := testCase.users | ||||
| 			scenario, err := NewScenario(caseSpec) | ||||
| 			require.NoError(t, err) | ||||
|  | ||||
| 			spec := testCase.users | ||||
|  | ||||
| 			err = scenario.CreateHeadscaleEnv(spec, | ||||
| 			err = scenario.CreateHeadscaleEnv( | ||||
| 				[]tsic.Option{}, | ||||
| 				hsic.WithACLPolicy(&testCase.policy), | ||||
| 			) | ||||
| @@ -944,6 +931,7 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { | ||||
| 	for name, testCase := range tests { | ||||
| 		t.Run(name, func(t *testing.T) { | ||||
| 			scenario := aclScenario(t, &testCase.policy, 1) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 			test1ip := netip.MustParseAddr("100.64.0.1") | ||||
| 			test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") | ||||
| @@ -1022,16 +1010,16 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	require.NoError(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 		"user2": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{ | ||||
| 			// Alpine containers dont have ip6tables set up, which causes | ||||
| 			// tailscaled to stop configuring the wgengine, causing it | ||||
|   | ||||
| @@ -19,15 +19,15 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { | ||||
|  | ||||
| 	for _, https := range []bool{true, false} { | ||||
| 		t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { | ||||
| 			scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			spec := ScenarioSpec{ | ||||
| 				NodesPerUser: len(MustTestVersions), | ||||
| 				Users:        []string{"user1", "user2"}, | ||||
| 			} | ||||
|  | ||||
| 			scenario, err := NewScenario(spec) | ||||
| 			assertNoErr(t, err) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 			spec := map[string]int{ | ||||
| 				"user1": len(MustTestVersions), | ||||
| 				"user2": len(MustTestVersions), | ||||
| 			} | ||||
|  | ||||
| 			opts := []hsic.Option{hsic.WithTestName("pingallbyip")} | ||||
| 			if https { | ||||
| 				opts = append(opts, []hsic.Option{ | ||||
| @@ -35,7 +35,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { | ||||
| 				}...) | ||||
| 			} | ||||
|  | ||||
| 			err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) | ||||
| 			err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) | ||||
| 			assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 			allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -84,7 +84,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { | ||||
| 				time.Sleep(5 * time.Minute) | ||||
| 			} | ||||
|  | ||||
| 			for userName := range spec { | ||||
| 			for _, userName := range spec.Users { | ||||
| 				key, err := scenario.CreatePreAuthKey(userName, true, false) | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) | ||||
| @@ -152,16 +152,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, | ||||
| 		hsic.WithTestName("keyrelognewuser"), | ||||
| 		hsic.WithTLS(), | ||||
| 	) | ||||
| @@ -203,7 +203,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { | ||||
|  | ||||
| 	// Log in all clients as user1, iterating over the spec only returns the | ||||
| 	// clients, not the usernames. | ||||
| 	for userName := range spec { | ||||
| 	for _, userName := range spec.Users { | ||||
| 		err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) | ||||
| @@ -235,15 +235,15 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { | ||||
|  | ||||
| 	for _, https := range []bool{true, false} { | ||||
| 		t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { | ||||
| 			scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			spec := ScenarioSpec{ | ||||
| 				NodesPerUser: len(MustTestVersions), | ||||
| 				Users:        []string{"user1", "user2"}, | ||||
| 			} | ||||
|  | ||||
| 			scenario, err := NewScenario(spec) | ||||
| 			assertNoErr(t, err) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 			spec := map[string]int{ | ||||
| 				"user1": len(MustTestVersions), | ||||
| 				"user2": len(MustTestVersions), | ||||
| 			} | ||||
|  | ||||
| 			opts := []hsic.Option{hsic.WithTestName("pingallbyip")} | ||||
| 			if https { | ||||
| 				opts = append(opts, []hsic.Option{ | ||||
| @@ -251,7 +251,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { | ||||
| 				}...) | ||||
| 			} | ||||
|  | ||||
| 			err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) | ||||
| 			err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) | ||||
| 			assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 			allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -300,7 +300,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { | ||||
| 				time.Sleep(5 * time.Minute) | ||||
| 			} | ||||
|  | ||||
| 			for userName := range spec { | ||||
| 			for _, userName := range spec.Users { | ||||
| 				key, err := scenario.CreatePreAuthKey(userName, true, false) | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) | ||||
|   | ||||
| @@ -1,93 +1,58 @@ | ||||
| package integration | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/cookiejar" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"maps" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/google/go-cmp/cmp/cmpopts" | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| 	"github.com/oauth2-proxy/mockoidc" | ||||
| 	"github.com/ory/dockertest/v3" | ||||
| 	"github.com/ory/dockertest/v3/docker" | ||||
| 	"github.com/samber/lo" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	dockerContextPath      = "../." | ||||
| 	hsicOIDCMockHashLength = 6 | ||||
| 	defaultAccessTTL       = 10 * time.Minute | ||||
| ) | ||||
|  | ||||
| var errStatusCodeNotOK = errors.New("status code not OK") | ||||
|  | ||||
| type AuthOIDCScenario struct { | ||||
| 	*Scenario | ||||
|  | ||||
| 	mockOIDC *dockertest.Resource | ||||
| } | ||||
|  | ||||
| func TestOIDCAuthenticationPingAll(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	// Logins to MockOIDC is served by a queue with a strict order, | ||||
| 	// if we use more than one node per user, the order of the logins | ||||
| 	// will not be deterministic and the test will fail. | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 		"user2": 1, | ||||
| 	} | ||||
|  | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 		OIDCUsers: []mockoidc.MockUser{ | ||||
| 			oidcMockUser("user1", true), | ||||
| 			oidcMockUser("user2", false), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_ISSUER":             scenario.mockOIDC.Issuer(), | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          scenario.mockOIDC.ClientID(), | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 	err = scenario.CreateHeadscaleEnvWithLoginURL( | ||||
| 		nil, | ||||
| 		hsic.WithTestName("oidcauthping"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 		hsic.WithTLS(), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), | ||||
| 	) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| @@ -126,7 +91,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | ||||
| 			Name:       "user1", | ||||
| 			Email:      "user1@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user1", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:    3, | ||||
| @@ -138,7 +103,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | ||||
| 			Name:       "user2", | ||||
| 			Email:      "", // Unverified | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user2", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user2", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -158,37 +123,29 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { | ||||
|  | ||||
| 	shortAccessTTL := 5 * time.Minute | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	baseScenario.pool.MaxWait = 5 * time.Minute | ||||
|  | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 		"user2": 1, | ||||
| 	} | ||||
|  | ||||
| 	oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 		OIDCUsers: []mockoidc.MockUser{ | ||||
| 			oidcMockUser("user1", true), | ||||
| 			oidcMockUser("user2", false), | ||||
| 	}) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
| 		}, | ||||
| 		OIDCAccessTTL: shortAccessTTL, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":                oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":             oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET":         oidcConfig.ClientSecret, | ||||
| 		"HEADSCALE_OIDC_ISSUER":                scenario.mockOIDC.Issuer(), | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":             scenario.mockOIDC.ClientID(), | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET":         scenario.mockOIDC.ClientSecret(), | ||||
| 		"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1", | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 	err = scenario.CreateHeadscaleEnvWithLoginURL( | ||||
| 		nil, | ||||
| 		hsic.WithTestName("oidcexpirenodes"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 	) | ||||
| @@ -334,45 +291,35 @@ func TestOIDC024UserCreation(t *testing.T) { | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			assertNoErr(t, err) | ||||
|  | ||||
| 			scenario := AuthOIDCScenario{ | ||||
| 				Scenario: baseScenario, | ||||
| 			spec := ScenarioSpec{ | ||||
| 				NodesPerUser: 1, | ||||
| 			} | ||||
| 			for _, user := range tt.cliUsers { | ||||
| 				spec.Users = append(spec.Users, user) | ||||
| 			} | ||||
|  | ||||
| 			for _, user := range tt.oidcUsers { | ||||
| 				spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified)) | ||||
| 			} | ||||
|  | ||||
| 			scenario, err := NewScenario(spec) | ||||
| 			assertNoErr(t, err) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 			spec := map[string]int{} | ||||
| 			for _, user := range tt.cliUsers { | ||||
| 				spec[user] = 1 | ||||
| 			} | ||||
|  | ||||
| 			var mockusers []mockoidc.MockUser | ||||
| 			for _, user := range tt.oidcUsers { | ||||
| 				mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) | ||||
| 			} | ||||
|  | ||||
| 			oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 			assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 			defer scenario.mockOIDC.Close() | ||||
|  | ||||
| 			oidcMap := map[string]string{ | ||||
| 				"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 				"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 				"HEADSCALE_OIDC_ISSUER":             scenario.mockOIDC.Issuer(), | ||||
| 				"HEADSCALE_OIDC_CLIENT_ID":          scenario.mockOIDC.ClientID(), | ||||
| 				"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 				"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 			} | ||||
| 			maps.Copy(oidcMap, tt.config) | ||||
|  | ||||
| 			for k, v := range tt.config { | ||||
| 				oidcMap[k] = v | ||||
| 			} | ||||
|  | ||||
| 			err = scenario.CreateHeadscaleEnv( | ||||
| 				spec, | ||||
| 			err = scenario.CreateHeadscaleEnvWithLoginURL( | ||||
| 				nil, | ||||
| 				hsic.WithTestName("oidcmigration"), | ||||
| 				hsic.WithConfigEnv(oidcMap), | ||||
| 				hsic.WithTLS(), | ||||
| 				hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 				hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), | ||||
| 			) | ||||
| 			assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| @@ -384,7 +331,7 @@ func TestOIDC024UserCreation(t *testing.T) { | ||||
| 			headscale, err := scenario.Headscale() | ||||
| 			assertNoErr(t, err) | ||||
|  | ||||
| 			want := tt.want(oidcConfig.Issuer) | ||||
| 			want := tt.want(scenario.mockOIDC.Issuer()) | ||||
|  | ||||
| 			listUsers, err := headscale.ListUsers() | ||||
| 			assertNoErr(t, err) | ||||
| @@ -404,41 +351,33 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	// Single user with one node for testing PKCE flow | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1"}, | ||||
| 		OIDCUsers: []mockoidc.MockUser{ | ||||
| 			oidcMockUser("user1", true), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	// Single user with one node for testing PKCE flow | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 	} | ||||
|  | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 		oidcMockUser("user1", true), | ||||
| 	} | ||||
|  | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	defer scenario.mockOIDC.Close() | ||||
|  | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_ISSUER":             scenario.mockOIDC.Issuer(), | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          scenario.mockOIDC.ClientID(), | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_PKCE_ENABLED":       "1", // Enable PKCE | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 	err = scenario.CreateHeadscaleEnvWithLoginURL( | ||||
| 		nil, | ||||
| 		hsic.WithTestName("oidcauthpkce"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 		hsic.WithTLS(), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), | ||||
| 	) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| @@ -464,43 +403,33 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	scenario := AuthOIDCScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	// Create no nodes and no users | ||||
| 	spec := map[string]int{} | ||||
|  | ||||
| 	scenario, err := NewScenario(ScenarioSpec{ | ||||
| 		// First login creates the first OIDC user | ||||
| 		// Second login logs in the same node, which creates a new node | ||||
| 		// Third login logs in the same node back into the original user | ||||
| 	mockusers := []mockoidc.MockUser{ | ||||
| 		OIDCUsers: []mockoidc.MockUser{ | ||||
| 			oidcMockUser("user1", true), | ||||
| 			oidcMockUser("user2", true), | ||||
| 			oidcMockUser("user1", true), | ||||
| 	} | ||||
|  | ||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||
| 	// defer scenario.mockOIDC.Close() | ||||
| 		}, | ||||
| 	}) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	oidcMap := map[string]string{ | ||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||
| 		"HEADSCALE_OIDC_ISSUER":             scenario.mockOIDC.Issuer(), | ||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          scenario.mockOIDC.ClientID(), | ||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 	err = scenario.CreateHeadscaleEnvWithLoginURL( | ||||
| 		nil, | ||||
| 		hsic.WithTestName("oidcauthrelog"), | ||||
| 		hsic.WithConfigEnv(oidcMap), | ||||
| 		hsic.WithTLS(), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||
| 		hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| 	) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
| @@ -512,7 +441,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||
| 	assertNoErr(t, err) | ||||
| 	assert.Len(t, listUsers, 0) | ||||
|  | ||||
| 	ts, err := scenario.CreateTailscaleNode("unstable") | ||||
| 	ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	u, err := ts.LoginWithURL(headscale.GetEndpoint()) | ||||
| @@ -530,7 +459,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||
| 			Name:       "user1", | ||||
| 			Email:      "user1@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user1", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user1", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -575,14 +504,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||
| 			Name:       "user1", | ||||
| 			Email:      "user1@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user1", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         2, | ||||
| 			Name:       "user2", | ||||
| 			Email:      "user2@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user2", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user2", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -632,14 +561,14 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||
| 			Name:       "user1", | ||||
| 			Email:      "user1@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user1", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			Id:         2, | ||||
| 			Name:       "user2", | ||||
| 			Email:      "user2@headscale.net", | ||||
| 			Provider:   "oidc", | ||||
| 			ProviderId: oidcConfig.Issuer + "/user2", | ||||
| 			ProviderId: scenario.mockOIDC.Issuer() + "/user2", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| @@ -678,254 +607,6 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { | ||||
| 	assert.NotEqual(t, listNodesAfterLoggingBackIn[0].NodeKey, listNodesAfterLoggingBackIn[1].NodeKey) | ||||
| } | ||||
|  | ||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||
| 	users map[string]int, | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	headscale, err := s.Headscale(opts...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = headscale.WaitForRunning() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for userName, clientCount := range users { | ||||
| 		if clientCount != 1 { | ||||
| 			// OIDC scenario only supports one client per user. | ||||
| 			// This is because the MockOIDC server can only serve login | ||||
| 			// requests based on a queue it has been given on startup. | ||||
| 			// We currently only populates it with one login request per user. | ||||
| 			return fmt.Errorf("client count must be 1 for OIDC scenario.") | ||||
| 		} | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = s.runTailscaleUp(userName, headscale.GetEndpoint()) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { | ||||
| 	port, err := dockertestutil.RandomFreeHostPort() | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("could not find an open port: %s", err) | ||||
| 	} | ||||
| 	portNotation := fmt.Sprintf("%d/tcp", port) | ||||
|  | ||||
| 	hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) | ||||
|  | ||||
| 	hostname := fmt.Sprintf("hs-oidcmock-%s", hash) | ||||
|  | ||||
| 	usersJSON, err := json.Marshal(users) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	mockOidcOptions := &dockertest.RunOptions{ | ||||
| 		Name:         hostname, | ||||
| 		Cmd:          []string{"headscale", "mockoidc"}, | ||||
| 		ExposedPorts: []string{portNotation}, | ||||
| 		PortBindings: map[docker.Port][]docker.PortBinding{ | ||||
| 			docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, | ||||
| 		}, | ||||
| 		Networks: []*dockertest.Network{s.Scenario.network}, | ||||
| 		Env: []string{ | ||||
| 			fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), | ||||
| 			fmt.Sprintf("MOCKOIDC_PORT=%d", port), | ||||
| 			"MOCKOIDC_CLIENT_ID=superclient", | ||||
| 			"MOCKOIDC_CLIENT_SECRET=supersecret", | ||||
| 			fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), | ||||
| 			fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	headscaleBuildOptions := &dockertest.BuildOptions{ | ||||
| 		Dockerfile: hsic.IntegrationTestDockerFileName, | ||||
| 		ContextDir: dockerContextPath, | ||||
| 	} | ||||
|  | ||||
| 	err = s.pool.RemoveContainerByName(hostname) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( | ||||
| 		headscaleBuildOptions, | ||||
| 		mockOidcOptions, | ||||
| 		dockertestutil.DockerRestartPolicy); err == nil { | ||||
| 		s.mockOIDC = pmockoidc | ||||
| 	} else { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	log.Println("Waiting for headscale mock oidc to be ready for tests") | ||||
| 	hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port) | ||||
|  | ||||
| 	if err := s.pool.Retry(func() error { | ||||
| 		oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) | ||||
| 		httpClient := &http.Client{} | ||||
| 		ctx := context.Background() | ||||
| 		req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) | ||||
| 		resp, err := httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			log.Printf("headscale mock OIDC tests is not ready: %s\n", err) | ||||
|  | ||||
| 			return err | ||||
| 		} | ||||
| 		defer resp.Body.Close() | ||||
|  | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			return errStatusCodeNotOK | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) | ||||
|  | ||||
| 	return &types.OIDCConfig{ | ||||
| 		Issuer: fmt.Sprintf( | ||||
| 			"http://%s/oidc", | ||||
| 			net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)), | ||||
| 		), | ||||
| 		ClientID:                   "superclient", | ||||
| 		ClientSecret:               "supersecret", | ||||
| 		OnlyStartIfOIDCIsAvailable: true, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| type LoggingRoundTripper struct{} | ||||
|  | ||||
| func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| 	noTls := &http.Transport{ | ||||
| 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint | ||||
| 	} | ||||
| 	resp, err := noTls.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("---") | ||||
| 	log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) | ||||
| 	log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) | ||||
|  | ||||
| 	return resp, nil | ||||
| } | ||||
|  | ||||
| func (s *AuthOIDCScenario) runTailscaleUp( | ||||
| 	userStr, loginServer string, | ||||
| ) error { | ||||
| 	log.Printf("running tailscale up for user %s", userStr) | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		for _, client := range user.Clients { | ||||
| 			tsc := client | ||||
| 			user.joinWaitGroup.Go(func() error { | ||||
| 				loginURL, err := tsc.LoginWithURL(loginServer) | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) | ||||
| 				} | ||||
|  | ||||
| 				_, err = doLoginURL(tsc.Hostname(), loginURL) | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
|  | ||||
| 				return nil | ||||
| 			}) | ||||
|  | ||||
| 			log.Printf("client %s is ready", client.Hostname()) | ||||
| 		} | ||||
|  | ||||
| 		if err := user.joinWaitGroup.Wait(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		for _, client := range user.Clients { | ||||
| 			err := client.WaitForRunning() | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf( | ||||
| 					"%s tailscale node has not reached running: %w", | ||||
| 					client.Hostname(), | ||||
| 					err, | ||||
| 				) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
|  | ||||
| // doLoginURL visits the given login URL and returns the body as a | ||||
| // string. | ||||
| func doLoginURL(hostname string, loginURL *url.URL) (string, error) { | ||||
| 	log.Printf("%s login url: %s\n", hostname, loginURL.String()) | ||||
|  | ||||
| 	var err error | ||||
| 	hc := &http.Client{ | ||||
| 		Transport: LoggingRoundTripper{}, | ||||
| 	} | ||||
| 	hc.Jar, err = cookiejar.New(nil) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("%s failed to create cookiejar	: %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("%s logging in with url", hostname) | ||||
| 	ctx := context.Background() | ||||
| 	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) | ||||
| 	resp, err := hc.Do(req) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("%s failed to send http request: %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		body, _ := io.ReadAll(resp.Body) | ||||
| 		log.Printf("body: %s", body) | ||||
|  | ||||
| 		return "", fmt.Errorf("%s response code of login request was %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		log.Printf("%s failed to read response body: %s", hostname, err) | ||||
|  | ||||
| 		return "", fmt.Errorf("%s failed to read response body: %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	return string(body), nil | ||||
| } | ||||
|  | ||||
| func (s *AuthOIDCScenario) Shutdown() { | ||||
| 	err := s.pool.Purge(s.mockOIDC) | ||||
| 	if err != nil { | ||||
| 		log.Printf("failed to remove mock oidc container") | ||||
| 	} | ||||
|  | ||||
| 	s.Scenario.Shutdown() | ||||
| } | ||||
|  | ||||
| func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { | ||||
| 	t.Helper() | ||||
|  | ||||
|   | ||||
| @@ -1,47 +1,33 @@ | ||||
| package integration | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"slices" | ||||
|  | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/samber/lo" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
|  | ||||
| var errParseAuthPage = errors.New("failed to parse auth page") | ||||
|  | ||||
| type AuthWebFlowScenario struct { | ||||
| 	*Scenario | ||||
| } | ||||
|  | ||||
| func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to create scenario: %s", err) | ||||
| 	} | ||||
|  | ||||
| 	scenario := AuthWebFlowScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 		nil, | ||||
| 		hsic.WithTestName("webauthping"), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| 		hsic.WithTLS(), | ||||
| @@ -71,20 +57,17 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	scenario := AuthWebFlowScenario{ | ||||
| 		Scenario: baseScenario, | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		nil, | ||||
| 		hsic.WithTestName("weblogout"), | ||||
| 		hsic.WithTLS(), | ||||
| 	) | ||||
| @@ -137,8 +120,8 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | ||||
|  | ||||
| 	t.Logf("all clients logged out") | ||||
|  | ||||
| 	for userName := range spec { | ||||
| 		err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) | ||||
| 	for _, userName := range spec.Users { | ||||
| 		err = scenario.RunTailscaleUpWithURL(userName, headscale.GetEndpoint()) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err) | ||||
| 		} | ||||
| @@ -172,14 +155,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | ||||
| 		} | ||||
|  | ||||
| 		for _, ip := range ips { | ||||
| 			found := false | ||||
| 			for _, oldIP := range clientIPs[client] { | ||||
| 				if ip == oldIP { | ||||
| 					found = true | ||||
|  | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
| 			found := slices.Contains(clientIPs[client], ip) | ||||
|  | ||||
| 			if !found { | ||||
| 				t.Fatalf( | ||||
| @@ -194,122 +170,3 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { | ||||
|  | ||||
| 	t.Logf("all clients IPs are the same") | ||||
| } | ||||
|  | ||||
| func (s *AuthWebFlowScenario) CreateHeadscaleEnv( | ||||
| 	users map[string]int, | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	headscale, err := s.Headscale(opts...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = headscale.WaitForRunning() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for userName, clientCount := range users { | ||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||
| 		err = s.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = s.runTailscaleUp(userName, headscale.GetEndpoint()) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *AuthWebFlowScenario) runTailscaleUp( | ||||
| 	userStr, loginServer string, | ||||
| ) error { | ||||
| 	log.Printf("running tailscale up for user %q", userStr) | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		for _, client := range user.Clients { | ||||
| 			c := client | ||||
| 			user.joinWaitGroup.Go(func() error { | ||||
| 				log.Printf("logging %q into %q", c.Hostname(), loginServer) | ||||
| 				loginURL, err := c.LoginWithURL(loginServer) | ||||
| 				if err != nil { | ||||
| 					log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err) | ||||
|  | ||||
| 					return err | ||||
| 				} | ||||
|  | ||||
| 				err = s.runHeadscaleRegister(userStr, loginURL) | ||||
| 				if err != nil { | ||||
| 					log.Printf("failed to register client (%s): %s", c.Hostname(), err) | ||||
|  | ||||
| 					return err | ||||
| 				} | ||||
|  | ||||
| 				return nil | ||||
| 			}) | ||||
|  | ||||
| 			err := client.WaitForRunning() | ||||
| 			if err != nil { | ||||
| 				log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if err := user.joinWaitGroup.Wait(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		for _, client := range user.Clients { | ||||
| 			err := client.WaitForRunning() | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
|  | ||||
| func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error { | ||||
| 	body, err := doLoginURL("web-auth-not-set", loginURL) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// see api.go HTML template | ||||
| 	codeSep := strings.Split(string(body), "</code>") | ||||
| 	if len(codeSep) != 2 { | ||||
| 		return errParseAuthPage | ||||
| 	} | ||||
|  | ||||
| 	keySep := strings.Split(codeSep[0], "key ") | ||||
| 	if len(keySep) != 2 { | ||||
| 		return errParseAuthPage | ||||
| 	} | ||||
| 	key := keySep[1] | ||||
| 	log.Printf("registering node %s", key) | ||||
|  | ||||
| 	if headscale, err := s.Headscale(); err == nil { | ||||
| 		_, err = headscale.Execute( | ||||
| 			[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			log.Printf("failed to register node: %s", err) | ||||
|  | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) | ||||
| } | ||||
|   | ||||
| @@ -48,16 +48,15 @@ func TestUserCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 0, | ||||
| 		"user2": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -247,15 +246,15 @@ func TestPreAuthKeyCommand(t *testing.T) { | ||||
| 	user := "preauthkeyspace" | ||||
| 	count := 3 | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{user}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		user: 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -388,16 +387,15 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	user := "pre-auth-key-without-exp-user" | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{user}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		user: 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -451,16 +449,15 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	user := "pre-auth-key-reus-ephm-user" | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{user}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		user: 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -530,17 +527,16 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { | ||||
| 	user1 := "user1" | ||||
| 	user2 := "user2" | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{user1}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		user1: 1, | ||||
| 		user2: 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("clipak"), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| @@ -551,6 +547,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { | ||||
| 	headscale, err := scenario.Headscale() | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	err = headscale.CreateUser(user2) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	var user2Key v1.PreAuthKey | ||||
|  | ||||
| 	err = executeAndUnmarshal( | ||||
| @@ -573,10 +572,15 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { | ||||
| 	) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	listNodes, err := headscale.ListNodes() | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, listNodes, 1) | ||||
| 	assert.Equal(t, user1, listNodes[0].GetUser().GetName()) | ||||
|  | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| 	assertNoErrListClients(t, err) | ||||
|  | ||||
| 	assert.Len(t, allClients, 1) | ||||
| 	require.Len(t, allClients, 1) | ||||
|  | ||||
| 	client := allClients[0] | ||||
|  | ||||
| @@ -606,12 +610,11 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { | ||||
| 		t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String()) | ||||
| 	} | ||||
|  | ||||
| 	listNodes, err := headscale.ListNodes() | ||||
| 	assert.Nil(t, err) | ||||
| 	assert.Len(t, listNodes, 2) | ||||
|  | ||||
| 	assert.Equal(t, "user1", listNodes[0].GetUser().GetName()) | ||||
| 	assert.Equal(t, "user2", listNodes[1].GetUser().GetName()) | ||||
| 	listNodes, err = headscale.ListNodes() | ||||
| 	require.Nil(t, err) | ||||
| 	require.Len(t, listNodes, 2) | ||||
| 	assert.Equal(t, user1, listNodes[0].GetUser().GetName()) | ||||
| 	assert.Equal(t, user2, listNodes[1].GetUser().GetName()) | ||||
| } | ||||
|  | ||||
| func TestApiKeyCommand(t *testing.T) { | ||||
| @@ -620,16 +623,15 @@ func TestApiKeyCommand(t *testing.T) { | ||||
|  | ||||
| 	count := 5 | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 0, | ||||
| 		"user2": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -788,15 +790,15 @@ func TestNodeTagCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -977,15 +979,16 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			spec := ScenarioSpec{ | ||||
| 				NodesPerUser: 1, | ||||
| 				Users:        []string{"user1"}, | ||||
| 			} | ||||
|  | ||||
| 			scenario, err := NewScenario(spec) | ||||
| 			assertNoErr(t, err) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 			spec := map[string]int{ | ||||
| 				"user1": 1, | ||||
| 			} | ||||
|  | ||||
| 			err = scenario.CreateHeadscaleEnv(spec, | ||||
| 			err = scenario.CreateHeadscaleEnv( | ||||
| 				[]tsic.Option{tsic.WithTags([]string{"tag:test"})}, | ||||
| 				hsic.WithTestName("cliadvtags"), | ||||
| 				hsic.WithACLPolicy(tt.policy), | ||||
| @@ -996,7 +999,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) { | ||||
| 			assertNoErr(t, err) | ||||
|  | ||||
| 			// Test list all nodes after added seconds | ||||
| 			resultMachines := make([]*v1.Node, spec["user1"]) | ||||
| 			resultMachines := make([]*v1.Node, spec.NodesPerUser) | ||||
| 			err = executeAndUnmarshal( | ||||
| 				headscale, | ||||
| 				[]string{ | ||||
| @@ -1029,16 +1032,15 @@ func TestNodeCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"node-user", "other-user"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"node-user":  0, | ||||
| 		"other-user": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -1269,15 +1271,15 @@ func TestNodeExpireCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"node-expire-user"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"node-expire-user": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -1395,15 +1397,15 @@ func TestNodeRenameCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"node-rename-command"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"node-rename-command": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -1560,16 +1562,15 @@ func TestNodeMoveCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"old-user", "new-user"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"old-user": 0, | ||||
| 		"new-user": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -1721,16 +1722,15 @@ func TestPolicyCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		Users: []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 0, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("clins"), | ||||
| 		hsic.WithConfigEnv(map[string]string{ | ||||
| @@ -1808,16 +1808,16 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("clins"), | ||||
| 		hsic.WithConfigEnv(map[string]string{ | ||||
|   | ||||
| @@ -24,5 +24,4 @@ type ControlServer interface { | ||||
| 	ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) | ||||
| 	GetCert() []byte | ||||
| 	GetHostname() string | ||||
| 	GetIP() string | ||||
| } | ||||
|   | ||||
| @@ -31,14 +31,15 @@ func TestDERPVerifyEndpoint(t *testing.T) { | ||||
| 	certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	derper, err := scenario.CreateDERPServer("head", | ||||
| 		dsic.WithCACert(certHeadscale), | ||||
| 		dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), | ||||
| @@ -65,7 +66,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithCACert(derper.GetCert())}, | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithCACert(derper.GetCert())}, | ||||
| 		hsic.WithHostname(hostname), | ||||
| 		hsic.WithPort(headscalePort), | ||||
| 		hsic.WithCustomTLS(certHeadscale, keyHeadscale), | ||||
|   | ||||
| @@ -17,16 +17,16 @@ func TestResolveMagicDNS(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"magicdns1": len(MustTestVersions), | ||||
| 		"magicdns2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns")) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -87,15 +87,15 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"magicdns1": 1, | ||||
| 		"magicdns2": 1, | ||||
| 	} | ||||
|  | ||||
| 	const erPath = "/tmp/extra_records.json" | ||||
|  | ||||
| 	extraRecords := []tailcfg.DNSRecord{ | ||||
| @@ -107,7 +107,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { | ||||
| 	} | ||||
| 	b, _ := json.Marshal(extraRecords) | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{ | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{ | ||||
| 		tsic.WithDockerEntrypoint([]string{ | ||||
| 			"/bin/sh", | ||||
| 			"-c", | ||||
| @@ -364,16 +364,16 @@ func TestValidateResolvConf(t *testing.T) { | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			spec := ScenarioSpec{ | ||||
| 				NodesPerUser: 3, | ||||
| 				Users:        []string{"user1", "user2"}, | ||||
| 			} | ||||
|  | ||||
| 			scenario, err := NewScenario(spec) | ||||
| 			assertNoErr(t, err) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 			spec := map[string]int{ | ||||
| 				"resolvconf1": 3, | ||||
| 				"resolvconf2": 3, | ||||
| 			} | ||||
|  | ||||
| 			err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf)) | ||||
| 			err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("resolvconf"), hsic.WithConfigEnv(tt.conf)) | ||||
| 			assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 			allClients, err := scenario.ListTailscaleClients() | ||||
|   | ||||
| @@ -35,7 +35,7 @@ type DERPServerInContainer struct { | ||||
|  | ||||
| 	pool      *dockertest.Pool | ||||
| 	container *dockertest.Resource | ||||
| 	network   *dockertest.Network | ||||
| 	networks  []*dockertest.Network | ||||
|  | ||||
| 	stunPort            int | ||||
| 	derpPort            int | ||||
| @@ -63,22 +63,22 @@ func WithCACert(cert []byte) Option { | ||||
| // isolating the DERPer, will be created. If a network is | ||||
| // passed, the DERPer instance will join the given network. | ||||
| func WithOrCreateNetwork(network *dockertest.Network) Option { | ||||
| 	return func(tsic *DERPServerInContainer) { | ||||
| 	return func(dsic *DERPServerInContainer) { | ||||
| 		if network != nil { | ||||
| 			tsic.network = network | ||||
| 			dsic.networks = append(dsic.networks, network) | ||||
|  | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		network, err := dockertestutil.GetFirstOrCreateNetwork( | ||||
| 			tsic.pool, | ||||
| 			tsic.hostname+"-network", | ||||
| 			dsic.pool, | ||||
| 			dsic.hostname+"-network", | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("failed to create network: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		tsic.network = network | ||||
| 		dsic.networks = append(dsic.networks, network) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -107,7 +107,7 @@ func WithExtraHosts(hosts []string) Option { | ||||
| func New( | ||||
| 	pool *dockertest.Pool, | ||||
| 	version string, | ||||
| 	network *dockertest.Network, | ||||
| 	networks []*dockertest.Network, | ||||
| 	opts ...Option, | ||||
| ) (*DERPServerInContainer, error) { | ||||
| 	hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength) | ||||
| @@ -124,7 +124,7 @@ func New( | ||||
| 		version:  version, | ||||
| 		hostname: hostname, | ||||
| 		pool:     pool, | ||||
| 		network:  network, | ||||
| 		networks: networks, | ||||
| 		tlsCert:  tlsCert, | ||||
| 		tlsKey:   tlsKey, | ||||
| 		stunPort: 3478, //nolint | ||||
| @@ -148,7 +148,7 @@ func New( | ||||
|  | ||||
| 	runOptions := &dockertest.RunOptions{ | ||||
| 		Name:       hostname, | ||||
| 		Networks:   []*dockertest.Network{dsic.network}, | ||||
| 		Networks:   dsic.networks, | ||||
| 		ExtraHosts: dsic.withExtraHosts, | ||||
| 		// we currently need to give us some time to inject the certificate further down. | ||||
| 		Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()}, | ||||
|   | ||||
| @@ -1,18 +1,12 @@ | ||||
| package integration | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"log" | ||||
| 	"net/url" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| 	"github.com/ory/dockertest/v3" | ||||
| ) | ||||
|  | ||||
| type ClientsSpec struct { | ||||
| @@ -20,21 +14,18 @@ type ClientsSpec struct { | ||||
| 	WebsocketDERP int | ||||
| } | ||||
|  | ||||
| type EmbeddedDERPServerScenario struct { | ||||
| 	*Scenario | ||||
|  | ||||
| 	tsicNetworks map[string]*dockertest.Network | ||||
| } | ||||
|  | ||||
| func TestDERPServerScenario(t *testing.T) { | ||||
| 	spec := map[string]ClientsSpec{ | ||||
| 		"user1": { | ||||
| 			Plain:         len(MustTestVersions), | ||||
| 			WebsocketDERP: 0, | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2", "user3"}, | ||||
| 		Networks: map[string][]string{ | ||||
| 			"usernet1": {"user1"}, | ||||
| 			"usernet2": {"user2"}, | ||||
| 			"usernet3": {"user3"}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) { | ||||
| 	derpServerScenario(t, spec, false, func(scenario *Scenario) { | ||||
| 		allClients, err := scenario.ListTailscaleClients() | ||||
| 		assertNoErrListClients(t, err) | ||||
| 		t.Logf("checking %d clients for websocket connections", len(allClients)) | ||||
| @@ -52,14 +43,17 @@ func TestDERPServerScenario(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestDERPServerWebsocketScenario(t *testing.T) { | ||||
| 	spec := map[string]ClientsSpec{ | ||||
| 		"user1": { | ||||
| 			Plain:         0, | ||||
| 			WebsocketDERP: 2, | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2", "user3"}, | ||||
| 		Networks: map[string][]string{ | ||||
| 			"usernet1": []string{"user1"}, | ||||
| 			"usernet2": []string{"user2"}, | ||||
| 			"usernet3": []string{"user3"}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	derpServerScenario(t, spec, func(scenario *EmbeddedDERPServerScenario) { | ||||
| 	derpServerScenario(t, spec, true, func(scenario *Scenario) { | ||||
| 		allClients, err := scenario.ListTailscaleClients() | ||||
| 		assertNoErrListClients(t, err) | ||||
| 		t.Logf("checking %d clients for websocket connections", len(allClients)) | ||||
| @@ -83,23 +77,22 @@ func TestDERPServerWebsocketScenario(t *testing.T) { | ||||
| //nolint:thelper | ||||
| func derpServerScenario( | ||||
| 	t *testing.T, | ||||
| 	spec map[string]ClientsSpec, | ||||
| 	furtherAssertions ...func(*EmbeddedDERPServerScenario), | ||||
| 	spec ScenarioSpec, | ||||
| 	websocket bool, | ||||
| 	furtherAssertions ...func(*Scenario), | ||||
| ) { | ||||
| 	IntegrationSkip(t) | ||||
| 	// t.Parallel() | ||||
|  | ||||
| 	baseScenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	scenario := EmbeddedDERPServerScenario{ | ||||
| 		Scenario:     baseScenario, | ||||
| 		tsicNetworks: map[string]*dockertest.Network{}, | ||||
| 	} | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		spec, | ||||
| 		[]tsic.Option{ | ||||
| 			tsic.WithWebsocketDERP(websocket), | ||||
| 		}, | ||||
| 		hsic.WithTestName("derpserver"), | ||||
| 		hsic.WithExtraPorts([]string{"3478/udp"}), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| @@ -185,182 +178,6 @@ func derpServerScenario( | ||||
| 	t.Logf("Run2: %d successful pings out of %d", success, len(allClients)*len(allHostnames)) | ||||
|  | ||||
| 	for _, check := range furtherAssertions { | ||||
| 		check(&scenario) | ||||
| 		check(scenario) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv( | ||||
| 	users map[string]ClientsSpec, | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	hsServer, err := s.Headscale(opts...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	headscaleEndpoint := hsServer.GetEndpoint() | ||||
| 	headscaleURL, err := url.Parse(headscaleEndpoint) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	headscaleURL.Host = fmt.Sprintf("%s:%s", hsServer.GetHostname(), headscaleURL.Port()) | ||||
|  | ||||
| 	err = hsServer.WaitForRunning() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	log.Printf("headscale server ip address: %s", hsServer.GetIP()) | ||||
|  | ||||
| 	hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for userName, clientCount := range users { | ||||
| 		err = s.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if clientCount.Plain > 0 { | ||||
| 			// Containers that use default DERP config | ||||
| 			err = s.CreateTailscaleIsolatedNodesInUser( | ||||
| 				hash, | ||||
| 				userName, | ||||
| 				"all", | ||||
| 				clientCount.Plain, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if clientCount.WebsocketDERP > 0 { | ||||
| 			// Containers that use DERP-over-WebSocket | ||||
| 			// Note that these clients *must* be built | ||||
| 			// from source, which is currently | ||||
| 			// only done for HEAD. | ||||
| 			err = s.CreateTailscaleIsolatedNodesInUser( | ||||
| 				hash, | ||||
| 				userName, | ||||
| 				tsic.VersionHead, | ||||
| 				clientCount.WebsocketDERP, | ||||
| 				tsic.WithWebsocketDERP(true), | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		key, err := s.CreatePreAuthKey(userName, true, false) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = s.RunTailscaleUp(userName, headscaleURL.String(), key.GetKey()) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser( | ||||
| 	hash string, | ||||
| 	userStr string, | ||||
| 	requestedVersion string, | ||||
| 	count int, | ||||
| 	opts ...tsic.Option, | ||||
| ) error { | ||||
| 	hsServer, err := s.Headscale() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		for clientN := 0; clientN < count; clientN++ { | ||||
| 			networkName := fmt.Sprintf("tsnet-%s-%s-%d", | ||||
| 				hash, | ||||
| 				userStr, | ||||
| 				clientN, | ||||
| 			) | ||||
| 			network, err := dockertestutil.GetFirstOrCreateNetwork( | ||||
| 				s.pool, | ||||
| 				networkName, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("failed to create or get %s network: %w", networkName, err) | ||||
| 			} | ||||
|  | ||||
| 			s.tsicNetworks[networkName] = network | ||||
|  | ||||
| 			err = hsServer.ConnectToNetwork(network) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("failed to connect headscale to %s network: %w", networkName, err) | ||||
| 			} | ||||
|  | ||||
| 			version := requestedVersion | ||||
| 			if requestedVersion == "all" { | ||||
| 				version = MustTestVersions[clientN%len(MustTestVersions)] | ||||
| 			} | ||||
|  | ||||
| 			cert := hsServer.GetCert() | ||||
|  | ||||
| 			opts = append(opts, | ||||
| 				tsic.WithCACert(cert), | ||||
| 			) | ||||
|  | ||||
| 			user.createWaitGroup.Go(func() error { | ||||
| 				tsClient, err := tsic.New( | ||||
| 					s.pool, | ||||
| 					version, | ||||
| 					network, | ||||
| 					opts..., | ||||
| 				) | ||||
| 				if err != nil { | ||||
| 					return fmt.Errorf( | ||||
| 						"failed to create tailscale (%s) node: %w", | ||||
| 						tsClient.Hostname(), | ||||
| 						err, | ||||
| 					) | ||||
| 				} | ||||
|  | ||||
| 				err = tsClient.WaitForNeedsLogin() | ||||
| 				if err != nil { | ||||
| 					return fmt.Errorf( | ||||
| 						"failed to wait for tailscaled (%s) to need login: %w", | ||||
| 						tsClient.Hostname(), | ||||
| 						err, | ||||
| 					) | ||||
| 				} | ||||
|  | ||||
| 				s.mu.Lock() | ||||
| 				user.Clients[tsClient.Hostname()] = tsClient | ||||
| 				s.mu.Unlock() | ||||
|  | ||||
| 				return nil | ||||
| 			}) | ||||
| 		} | ||||
|  | ||||
| 		if err := user.createWaitGroup.Wait(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Errorf("failed to add tailscale nodes: %w", errNoUserAvailable) | ||||
| } | ||||
|  | ||||
| func (s *EmbeddedDERPServerScenario) Shutdown() { | ||||
| 	for _, network := range s.tsicNetworks { | ||||
| 		err := s.pool.RemoveNetwork(network) | ||||
| 		if err != nil { | ||||
| 			log.Printf("failed to remove DERP network %s", network.Network.Name) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	s.Scenario.Shutdown() | ||||
| } | ||||
|   | ||||
| @@ -28,18 +28,17 @@ func TestPingAllByIP(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 		MaxWait:      dockertestMaxWait(), | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	// TODO(kradalby): it does not look like the user thing works, only second | ||||
| 	// get created? maybe only when many? | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("pingallbyip"), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| @@ -71,16 +70,16 @@ func TestPingAllByIPPublicDERP(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("pingallbyippubderp"), | ||||
| 	) | ||||
| @@ -121,25 +120,25 @@ func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	headscale, err := scenario.Headscale(opts...) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	for userName, clientCount := range spec { | ||||
| 	for _, userName := range spec.Users { | ||||
| 		err = scenario.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to create user %s: %s", userName, err) | ||||
| 		} | ||||
|  | ||||
| 		err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) | ||||
| 		err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) | ||||
| 		} | ||||
| @@ -194,15 +193,15 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	headscale, err := scenario.Headscale( | ||||
| 		hsic.WithTestName("ephemeral2006"), | ||||
| 		hsic.WithConfigEnv(map[string]string{ | ||||
| @@ -211,13 +210,13 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { | ||||
| 	) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	for userName, clientCount := range spec { | ||||
| 	for _, userName := range spec.Users { | ||||
| 		err = scenario.CreateUser(userName) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to create user %s: %s", userName, err) | ||||
| 		} | ||||
|  | ||||
| 		err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) | ||||
| 		err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) | ||||
| 		} | ||||
| @@ -287,7 +286,7 @@ func TestEphemeral2006DeletedTooQuickly(t *testing.T) { | ||||
| 	// registered. | ||||
| 	time.Sleep(3 * time.Minute) | ||||
|  | ||||
| 	for userName := range spec { | ||||
| 	for _, userName := range spec.Users { | ||||
| 		nodes, err := headscale.ListNodes(userName) | ||||
| 		if err != nil { | ||||
| 			log.Error(). | ||||
| @@ -308,16 +307,16 @@ func TestPingAllByHostname(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user3": len(MustTestVersions), | ||||
| 		"user4": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname")) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -357,15 +356,16 @@ func TestTaildrop(t *testing.T) { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"taildrop": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, | ||||
| 		hsic.WithTestName("taildrop"), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| 		hsic.WithTLS(), | ||||
| @@ -522,23 +522,22 @@ func TestUpdateHostnameFromClient(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	user := "update-hostname-from-client" | ||||
|  | ||||
| 	hostnames := map[string]string{ | ||||
| 		"1": "user1-host", | ||||
| 		"2": "User2-Host", | ||||
| 		"3": "user3-host", | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 3, | ||||
| 		Users:        []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErrf(t, "failed to create scenario: %s", err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		user: 3, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("updatehostname")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname")) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -650,15 +649,16 @@ func TestExpireNode(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("expirenode")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode")) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -684,7 +684,7 @@ func TestExpireNode(t *testing.T) { | ||||
| 		assertNoErr(t, err) | ||||
|  | ||||
| 		// Assert that we have the original count - self | ||||
| 		assert.Len(t, status.Peers(), spec["user1"]-1) | ||||
| 		assert.Len(t, status.Peers(), spec.NodesPerUser-1) | ||||
| 	} | ||||
|  | ||||
| 	headscale, err := scenario.Headscale() | ||||
| @@ -776,15 +776,16 @@ func TestNodeOnlineStatus(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("online")) | ||||
| 	err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online")) | ||||
| 	assertNoErrHeadscaleEnv(t, err) | ||||
|  | ||||
| 	allClients, err := scenario.ListTailscaleClients() | ||||
| @@ -891,18 +892,16 @@ func TestPingAllByIPManyUpDown(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: len(MustTestVersions), | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	// TODO(kradalby): it does not look like the user thing works, only second | ||||
| 	// get created? maybe only when many? | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": len(MustTestVersions), | ||||
| 		"user2": len(MustTestVersions), | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("pingallbyipmany"), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
| @@ -973,18 +972,16 @@ func Test2118DeletingOnlineNodePanics(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: 1, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
|  | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	// TODO(kradalby): it does not look like the user thing works, only second | ||||
| 	// get created? maybe only when many? | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": 1, | ||||
| 		"user2": 1, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{}, | ||||
| 		hsic.WithTestName("deletenocrash"), | ||||
| 		hsic.WithEmbeddedDERPServerOnly(), | ||||
|   | ||||
| @@ -56,7 +56,7 @@ type HeadscaleInContainer struct { | ||||
|  | ||||
| 	pool      *dockertest.Pool | ||||
| 	container *dockertest.Resource | ||||
| 	network   *dockertest.Network | ||||
| 	networks  []*dockertest.Network | ||||
|  | ||||
| 	pgContainer *dockertest.Resource | ||||
|  | ||||
| @@ -268,7 +268,7 @@ func WithTimezone(timezone string) Option { | ||||
| // New returns a new HeadscaleInContainer instance. | ||||
| func New( | ||||
| 	pool *dockertest.Pool, | ||||
| 	network *dockertest.Network, | ||||
| 	networks []*dockertest.Network, | ||||
| 	opts ...Option, | ||||
| ) (*HeadscaleInContainer, error) { | ||||
| 	hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength) | ||||
| @@ -283,7 +283,7 @@ func New( | ||||
| 		port:     headscaleDefaultPort, | ||||
|  | ||||
| 		pool:     pool, | ||||
| 		network: network, | ||||
| 		networks: networks, | ||||
|  | ||||
| 		env:              DefaultConfigEnv(), | ||||
| 		filesInContainer: []fileInContainer{}, | ||||
| @@ -315,7 +315,7 @@ func New( | ||||
| 				Name:       fmt.Sprintf("postgres-%s", hash), | ||||
| 				Repository: "postgres", | ||||
| 				Tag:        "latest", | ||||
| 				Networks:   []*dockertest.Network{network}, | ||||
| 				Networks:   networks, | ||||
| 				Env: []string{ | ||||
| 					"POSTGRES_USER=headscale", | ||||
| 					"POSTGRES_PASSWORD=headscale", | ||||
| @@ -357,7 +357,7 @@ func New( | ||||
| 	runOptions := &dockertest.RunOptions{ | ||||
| 		Name:         hsic.hostname, | ||||
| 		ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...), | ||||
| 		Networks:     []*dockertest.Network{network}, | ||||
| 		Networks:     networks, | ||||
| 		// Cmd:          []string{"headscale", "serve"}, | ||||
| 		// TODO(kradalby): Get rid of this hack, we currently need to give us some | ||||
| 		// to inject the headscale configuration further down. | ||||
| @@ -630,11 +630,6 @@ func (t *HeadscaleInContainer) Execute( | ||||
| 	return stdout, nil | ||||
| } | ||||
|  | ||||
| // GetIP returns the docker container IP as a string. | ||||
| func (t *HeadscaleInContainer) GetIP() string { | ||||
| 	return t.container.GetIPInNetwork(t.network) | ||||
| } | ||||
|  | ||||
| // GetPort returns the docker container port as a string. | ||||
| func (t *HeadscaleInContainer) GetPort() string { | ||||
| 	return fmt.Sprintf("%d", t.port) | ||||
|   | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -1,24 +1,37 @@ | ||||
| package integration | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/cookiejar" | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||
| 	"github.com/juanfont/headscale/hscontrol/capver" | ||||
| 	"github.com/juanfont/headscale/hscontrol/types" | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/dsic" | ||||
| 	"github.com/juanfont/headscale/integration/hsic" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| 	"github.com/oauth2-proxy/mockoidc" | ||||
| 	"github.com/ory/dockertest/v3" | ||||
| 	"github.com/ory/dockertest/v3/docker" | ||||
| 	"github.com/puzpuzpuz/xsync/v3" | ||||
| 	"github.com/samber/lo" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| @@ -26,6 +39,7 @@ import ( | ||||
| 	xmaps "golang.org/x/exp/maps" | ||||
| 	"golang.org/x/sync/errgroup" | ||||
| 	"tailscale.com/envknob" | ||||
| 	"tailscale.com/util/mak" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -87,32 +101,135 @@ type Scenario struct { | ||||
| 	users map[string]*User | ||||
|  | ||||
| 	pool          *dockertest.Pool | ||||
| 	network *dockertest.Network | ||||
| 	networks      map[string]*dockertest.Network | ||||
| 	mockOIDC      scenarioOIDC | ||||
| 	extraServices map[string][]*dockertest.Resource | ||||
|  | ||||
| 	mu sync.Mutex | ||||
|  | ||||
| 	spec          ScenarioSpec | ||||
| 	userToNetwork map[string]*dockertest.Network | ||||
| } | ||||
|  | ||||
| // ScenarioSpec describes the users, nodes, and network topology to | ||||
| // set up for a given scenario. | ||||
| type ScenarioSpec struct { | ||||
| 	// Users is a list of usernames that will be created. | ||||
| 	// Each created user will get nodes equivalent to NodesPerUser | ||||
| 	Users []string | ||||
|  | ||||
| 	// NodesPerUser is how many nodes should be attached to each user. | ||||
| 	NodesPerUser int | ||||
|  | ||||
| 	// Networks, if set, is the seperate Docker networks that should be | ||||
| 	// created and a list of the users that should be placed in those networks. | ||||
| 	// If not set, a single network will be created and all users+nodes will be | ||||
| 	// added there. | ||||
| 	// Please note that Docker networks are not necessarily routable and | ||||
| 	// connections between them might fall back to DERP. | ||||
| 	Networks map[string][]string | ||||
|  | ||||
| 	// ExtraService, if set, is additional a map of network to additional | ||||
| 	// container services that should be set up. These container services | ||||
| 	// typically dont run Tailscale, e.g. web service to test subnet router. | ||||
| 	ExtraService map[string][]extraServiceFunc | ||||
|  | ||||
| 	// Versions is specific list of versions to use for the test. | ||||
| 	Versions []string | ||||
|  | ||||
| 	// OIDCUsers, if populated, will start a Mock OIDC server and populate | ||||
| 	// the user login stack with the given users. | ||||
| 	// If the NodesPerUser is set, it should align with this list to ensure | ||||
| 	// the correct users are logged in. | ||||
| 	// This is because the MockOIDC server can only serve login | ||||
| 	// requests based on a queue it has been given on startup. | ||||
| 	// We currently only populates it with one login request per user. | ||||
| 	OIDCUsers     []mockoidc.MockUser | ||||
| 	OIDCAccessTTL time.Duration | ||||
|  | ||||
| 	MaxWait time.Duration | ||||
| } | ||||
|  | ||||
| var TestHashPrefix = "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) | ||||
| var TestDefaultNetwork = TestHashPrefix + "-default" | ||||
|  | ||||
| func prefixedNetworkName(name string) string { | ||||
| 	return TestHashPrefix + "-" + name | ||||
| } | ||||
|  | ||||
| // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with | ||||
| // a set of Users and TailscaleClients. | ||||
| func NewScenario(maxWait time.Duration) (*Scenario, error) { | ||||
| 	hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| func NewScenario(spec ScenarioSpec) (*Scenario, error) { | ||||
| 	pool, err := dockertest.NewPool("") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("could not connect to docker: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	pool.MaxWait = maxWait | ||||
|  | ||||
| 	networkName := fmt.Sprintf("hs-%s", hash) | ||||
| 	if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { | ||||
| 		networkName = overrideNetworkName | ||||
| 	if spec.MaxWait == 0 { | ||||
| 		pool.MaxWait = dockertestMaxWait() | ||||
| 	} else { | ||||
| 		pool.MaxWait = spec.MaxWait | ||||
| 	} | ||||
|  | ||||
| 	network, err := dockertestutil.GetFirstOrCreateNetwork(pool, networkName) | ||||
| 	s := &Scenario{ | ||||
| 		controlServers: xsync.NewMapOf[string, ControlServer](), | ||||
| 		users:          make(map[string]*User), | ||||
|  | ||||
| 		pool: pool, | ||||
| 		spec: spec, | ||||
| 	} | ||||
|  | ||||
| 	var userToNetwork map[string]*dockertest.Network | ||||
| 	if spec.Networks != nil || len(spec.Networks) != 0 { | ||||
| 		for name, users := range s.spec.Networks { | ||||
| 			networkName := TestHashPrefix + "-" + name | ||||
| 			network, err := s.AddNetwork(networkName) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			for _, user := range users { | ||||
| 				if n2, ok := userToNetwork[user]; ok { | ||||
| 					return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) | ||||
| 				} | ||||
| 				mak.Set(&userToNetwork, user, network) | ||||
| 			} | ||||
| 		} | ||||
| 	} else { | ||||
| 		_, err := s.AddNetwork(TestDefaultNetwork) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for network, extras := range spec.ExtraService { | ||||
| 		for _, extra := range extras { | ||||
| 			svc, err := extra(s, network) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			mak.Set(&s.extraServices, prefixedNetworkName(network), append(s.extraServices[prefixedNetworkName(network)], svc)) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	s.userToNetwork = userToNetwork | ||||
|  | ||||
| 	if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 { | ||||
| 		ttl := defaultAccessTTL | ||||
| 		if spec.OIDCAccessTTL != 0 { | ||||
| 			ttl = spec.OIDCAccessTTL | ||||
| 		} | ||||
| 		err = s.runMockOIDC(ttl, spec.OIDCUsers) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| func (s *Scenario) AddNetwork(name string) (*dockertest.Network, error) { | ||||
| 	network, err := dockertestutil.GetFirstOrCreateNetwork(s.pool, name) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create or get network: %w", err) | ||||
| 	} | ||||
| @@ -120,18 +237,58 @@ func NewScenario(maxWait time.Duration) (*Scenario, error) { | ||||
| 	// We run the test suite in a docker container that calls a couple of endpoints for | ||||
| 	// readiness checks, this ensures that we can run the tests with individual networks | ||||
| 	// and have the client reach the different containers | ||||
| 	err = dockertestutil.AddContainerToNetwork(pool, network, "headscale-test-suite") | ||||
| 	// TODO(kradalby): Can the test-suite be renamed so we can have multiple? | ||||
| 	err = dockertestutil.AddContainerToNetwork(s.pool, network, "headscale-test-suite") | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to add test suite container to network: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	return &Scenario{ | ||||
| 		controlServers: xsync.NewMapOf[string, ControlServer](), | ||||
| 		users:          make(map[string]*User), | ||||
| 	mak.Set(&s.networks, name, network) | ||||
|  | ||||
| 		pool:    pool, | ||||
| 		network: network, | ||||
| 	}, nil | ||||
| 	return network, nil | ||||
| } | ||||
|  | ||||
| func (s *Scenario) Networks() []*dockertest.Network { | ||||
| 	if len(s.networks) == 0 { | ||||
| 		panic("Scenario.Networks called with empty network list") | ||||
| 	} | ||||
| 	return xmaps.Values(s.networks) | ||||
| } | ||||
|  | ||||
| func (s *Scenario) Network(name string) (*dockertest.Network, error) { | ||||
| 	net, ok := s.networks[prefixedNetworkName(name)] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("no network named: %s", name) | ||||
| 	} | ||||
|  | ||||
| 	return net, nil | ||||
| } | ||||
|  | ||||
| func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { | ||||
| 	net, ok := s.networks[prefixedNetworkName(name)] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("no network named: %s", name) | ||||
| 	} | ||||
|  | ||||
| 	for _, ipam := range net.Network.IPAM.Config { | ||||
| 		pref, err := netip.ParsePrefix(ipam.Subnet) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		return &pref, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, fmt.Errorf("no prefix found in network: %s", name) | ||||
| } | ||||
|  | ||||
| func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { | ||||
| 	res, ok := s.extraServices[prefixedNetworkName(name)] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("no network named: %s", name) | ||||
| 	} | ||||
|  | ||||
| 	return res, nil | ||||
| } | ||||
|  | ||||
| func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { | ||||
| @@ -184,14 +341,27 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if err := s.pool.RemoveNetwork(s.network); err != nil { | ||||
| 		log.Printf("failed to remove network: %s", err) | ||||
| 	for _, svcs := range s.extraServices { | ||||
| 		for _, svc := range svcs { | ||||
| 			err := svc.Close() | ||||
| 			if err != nil { | ||||
| 				log.Printf("failed to tear down service %q: %s", svc.Container.Name, err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// TODO(kradalby): This seem redundant to the previous call | ||||
| 	// if err := s.network.Close(); err != nil { | ||||
| 	// 	return fmt.Errorf("failed to tear down network: %w", err) | ||||
| 	// } | ||||
| 	if s.mockOIDC.r != nil { | ||||
| 		s.mockOIDC.r.Close() | ||||
| 		if err := s.mockOIDC.r.Close(); err != nil { | ||||
| 			log.Printf("failed to tear down oidc server: %s", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for _, network := range s.networks { | ||||
| 		if err := network.Close(); err != nil { | ||||
| 			log.Printf("failed to tear down network: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient) | ||||
| @@ -235,7 +405,7 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { | ||||
| 		opts = append(opts, hsic.WithPolicyV2()) | ||||
| 	} | ||||
|  | ||||
| 	headscale, err := hsic.New(s.pool, s.network, opts...) | ||||
| 	headscale, err := hsic.New(s.pool, s.Networks(), opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create headscale container: %w", err) | ||||
| 	} | ||||
| @@ -312,7 +482,6 @@ func (s *Scenario) CreateTailscaleNode( | ||||
| 	tsClient, err := tsic.New( | ||||
| 		s.pool, | ||||
| 		version, | ||||
| 		s.network, | ||||
| 		opts..., | ||||
| 	) | ||||
| 	if err != nil { | ||||
| @@ -345,11 +514,15 @@ func (s *Scenario) CreateTailscaleNodesInUser( | ||||
| ) error { | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		var versions []string | ||||
| 		for i := 0; i < count; i++ { | ||||
| 		for i := range count { | ||||
| 			version := requestedVersion | ||||
| 			if requestedVersion == "all" { | ||||
| 				if s.spec.Versions != nil { | ||||
| 					version = s.spec.Versions[i%len(s.spec.Versions)] | ||||
| 				} else { | ||||
| 					version = MustTestVersions[i%len(MustTestVersions)] | ||||
| 				} | ||||
| 			} | ||||
| 			versions = append(versions, version) | ||||
|  | ||||
| 			headscale, err := s.Headscale() | ||||
| @@ -372,14 +545,12 @@ func (s *Scenario) CreateTailscaleNodesInUser( | ||||
| 				tsClient, err := tsic.New( | ||||
| 					s.pool, | ||||
| 					version, | ||||
| 					s.network, | ||||
| 					opts..., | ||||
| 				) | ||||
| 				s.mu.Unlock() | ||||
| 				if err != nil { | ||||
| 					return fmt.Errorf( | ||||
| 						"failed to create tailscale (%s) node: %w", | ||||
| 						tsClient.Hostname(), | ||||
| 						"failed to create tailscale node: %w", | ||||
| 						err, | ||||
| 					) | ||||
| 				} | ||||
| @@ -492,11 +663,24 @@ func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // CreateHeadscaleEnv is a convenient method returning a complete Headcale | ||||
| // test environment with nodes of all versions, joined to the server with X | ||||
| // users. | ||||
| func (s *Scenario) CreateHeadscaleEnvWithLoginURL( | ||||
| 	tsOpts []tsic.Option, | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	return s.createHeadscaleEnv(true, tsOpts, opts...) | ||||
| } | ||||
|  | ||||
| func (s *Scenario) CreateHeadscaleEnv( | ||||
| 	users map[string]int, | ||||
| 	tsOpts []tsic.Option, | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| 	return s.createHeadscaleEnv(false, tsOpts, opts...) | ||||
| } | ||||
|  | ||||
| // CreateHeadscaleEnv starts the headscale environment and the clients | ||||
| // according to the ScenarioSpec passed to the Scenario. | ||||
| func (s *Scenario) createHeadscaleEnv( | ||||
| 	withURL bool, | ||||
| 	tsOpts []tsic.Option, | ||||
| 	opts ...hsic.Option, | ||||
| ) error { | ||||
| @@ -505,34 +689,188 @@ func (s *Scenario) CreateHeadscaleEnv( | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	usernames := xmaps.Keys(users) | ||||
| 	sort.Strings(usernames) | ||||
| 	for _, username := range usernames { | ||||
| 		clientCount := users[username] | ||||
| 		err = s.CreateUser(username) | ||||
| 	sort.Strings(s.spec.Users) | ||||
| 	for _, user := range s.spec.Users { | ||||
| 		err = s.CreateUser(user) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = s.CreateTailscaleNodesInUser(username, "all", clientCount, tsOpts...) | ||||
| 		var opts []tsic.Option | ||||
| 		if s.userToNetwork != nil { | ||||
| 			opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user])) | ||||
| 		} else { | ||||
| 			opts = append(tsOpts, tsic.WithNetwork(s.networks[TestDefaultNetwork])) | ||||
| 		} | ||||
|  | ||||
| 		err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		key, err := s.CreatePreAuthKey(username, true, false) | ||||
| 		if withURL { | ||||
| 			err = s.RunTailscaleUpWithURL(user, headscale.GetEndpoint()) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} else { | ||||
| 			key, err := s.CreatePreAuthKey(user, true, false) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 		err = s.RunTailscaleUp(username, headscale.GetEndpoint(), key.GetKey()) | ||||
| 			err = s.RunTailscaleUp(user, headscale.GetEndpoint(), key.GetKey()) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { | ||||
| 	log.Printf("running tailscale up for user %s", userStr) | ||||
| 	if user, ok := s.users[userStr]; ok { | ||||
| 		for _, client := range user.Clients { | ||||
| 			tsc := client | ||||
| 			user.joinWaitGroup.Go(func() error { | ||||
| 				loginURL, err := tsc.LoginWithURL(loginServer) | ||||
| 				if err != nil { | ||||
| 					log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) | ||||
| 				} | ||||
|  | ||||
| 				body, err := doLoginURL(tsc.Hostname(), loginURL) | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
|  | ||||
| 				// If the URL is not a OIDC URL, then we need to | ||||
| 				// run the register command to fully log in the client. | ||||
| 				if !strings.Contains(loginURL.String(), "/oidc/") { | ||||
| 					s.runHeadscaleRegister(userStr, body) | ||||
| 				} | ||||
|  | ||||
| 				return nil | ||||
| 			}) | ||||
|  | ||||
| 			log.Printf("client %s is ready", client.Hostname()) | ||||
| 		} | ||||
|  | ||||
| 		if err := user.joinWaitGroup.Wait(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		for _, client := range user.Clients { | ||||
| 			err := client.WaitForRunning() | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf( | ||||
| 					"%s tailscale node has not reached running: %w", | ||||
| 					client.Hostname(), | ||||
| 					err, | ||||
| 				) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) | ||||
| } | ||||
|  | ||||
| // doLoginURL visits the given login URL and returns the body as a | ||||
| // string. | ||||
| func doLoginURL(hostname string, loginURL *url.URL) (string, error) { | ||||
| 	log.Printf("%s login url: %s\n", hostname, loginURL.String()) | ||||
|  | ||||
| 	var err error | ||||
| 	hc := &http.Client{ | ||||
| 		Transport: LoggingRoundTripper{}, | ||||
| 	} | ||||
| 	hc.Jar, err = cookiejar.New(nil) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("%s failed to create cookiejar	: %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("%s logging in with url", hostname) | ||||
| 	ctx := context.Background() | ||||
| 	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) | ||||
| 	resp, err := hc.Do(req) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("%s failed to send http request: %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		body, _ := io.ReadAll(resp.Body) | ||||
| 		log.Printf("body: %s", body) | ||||
|  | ||||
| 		return "", fmt.Errorf("%s response code of login request was %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		log.Printf("%s failed to read response body: %s", hostname, err) | ||||
|  | ||||
| 		return "", fmt.Errorf("%s failed to read response body: %w", hostname, err) | ||||
| 	} | ||||
|  | ||||
| 	return string(body), nil | ||||
| } | ||||
|  | ||||
| var errParseAuthPage = errors.New("failed to parse auth page") | ||||
|  | ||||
| func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { | ||||
| 	// see api.go HTML template | ||||
| 	codeSep := strings.Split(string(body), "</code>") | ||||
| 	if len(codeSep) != 2 { | ||||
| 		return errParseAuthPage | ||||
| 	} | ||||
|  | ||||
| 	keySep := strings.Split(codeSep[0], "key ") | ||||
| 	if len(keySep) != 2 { | ||||
| 		return errParseAuthPage | ||||
| 	} | ||||
| 	key := keySep[1] | ||||
| 	log.Printf("registering node %s", key) | ||||
|  | ||||
| 	if headscale, err := s.Headscale(); err == nil { | ||||
| 		_, err = headscale.Execute( | ||||
| 			[]string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			log.Printf("failed to register node: %s", err) | ||||
|  | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) | ||||
| } | ||||
|  | ||||
| type LoggingRoundTripper struct{} | ||||
|  | ||||
| func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| 	noTls := &http.Transport{ | ||||
| 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint | ||||
| 	} | ||||
| 	resp, err := noTls.RoundTrip(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("---") | ||||
| 	log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String()) | ||||
| 	log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies()) | ||||
|  | ||||
| 	return resp, nil | ||||
| } | ||||
|  | ||||
| // GetIPs returns all netip.Addr of TailscaleClients associated with a User | ||||
| // in a Scenario. | ||||
| func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { | ||||
| @@ -670,7 +1008,7 @@ func (s *Scenario) WaitForTailscaleLogout() error { | ||||
|  | ||||
| // CreateDERPServer creates a new DERP server in a container. | ||||
| func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) { | ||||
| 	derp, err := dsic.New(s.pool, version, s.network, opts...) | ||||
| 	derp, err := dsic.New(s.pool, version, s.Networks(), opts...) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to create DERP server: %w", err) | ||||
| 	} | ||||
| @@ -684,3 +1022,216 @@ func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic. | ||||
|  | ||||
| 	return derp, nil | ||||
| } | ||||
|  | ||||
| type scenarioOIDC struct { | ||||
| 	r   *dockertest.Resource | ||||
| 	cfg *types.OIDCConfig | ||||
| } | ||||
|  | ||||
| func (o *scenarioOIDC) Issuer() string { | ||||
| 	if o.cfg == nil { | ||||
| 		panic("OIDC has not been created") | ||||
| 	} | ||||
|  | ||||
| 	return o.cfg.Issuer | ||||
| } | ||||
|  | ||||
| func (o *scenarioOIDC) ClientSecret() string { | ||||
| 	if o.cfg == nil { | ||||
| 		panic("OIDC has not been created") | ||||
| 	} | ||||
|  | ||||
| 	return o.cfg.ClientSecret | ||||
| } | ||||
|  | ||||
| func (o *scenarioOIDC) ClientID() string { | ||||
| 	if o.cfg == nil { | ||||
| 		panic("OIDC has not been created") | ||||
| 	} | ||||
|  | ||||
| 	return o.cfg.ClientID | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	dockerContextPath      = "../." | ||||
| 	hsicOIDCMockHashLength = 6 | ||||
| 	defaultAccessTTL       = 10 * time.Minute | ||||
| ) | ||||
|  | ||||
| var errStatusCodeNotOK = errors.New("status code not OK") | ||||
|  | ||||
| func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) error { | ||||
| 	port, err := dockertestutil.RandomFreeHostPort() | ||||
| 	if err != nil { | ||||
| 		log.Fatalf("could not find an open port: %s", err) | ||||
| 	} | ||||
| 	portNotation := fmt.Sprintf("%d/tcp", port) | ||||
|  | ||||
| 	hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) | ||||
|  | ||||
| 	hostname := fmt.Sprintf("hs-oidcmock-%s", hash) | ||||
|  | ||||
| 	usersJSON, err := json.Marshal(users) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	mockOidcOptions := &dockertest.RunOptions{ | ||||
| 		Name:         hostname, | ||||
| 		Cmd:          []string{"headscale", "mockoidc"}, | ||||
| 		ExposedPorts: []string{portNotation}, | ||||
| 		PortBindings: map[docker.Port][]docker.PortBinding{ | ||||
| 			docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, | ||||
| 		}, | ||||
| 		Networks: s.Networks(), | ||||
| 		Env: []string{ | ||||
| 			fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), | ||||
| 			fmt.Sprintf("MOCKOIDC_PORT=%d", port), | ||||
| 			"MOCKOIDC_CLIENT_ID=superclient", | ||||
| 			"MOCKOIDC_CLIENT_SECRET=supersecret", | ||||
| 			fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), | ||||
| 			fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	headscaleBuildOptions := &dockertest.BuildOptions{ | ||||
| 		Dockerfile: hsic.IntegrationTestDockerFileName, | ||||
| 		ContextDir: dockerContextPath, | ||||
| 	} | ||||
|  | ||||
| 	err = s.pool.RemoveContainerByName(hostname) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	s.mockOIDC = scenarioOIDC{} | ||||
|  | ||||
| 	if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( | ||||
| 		headscaleBuildOptions, | ||||
| 		mockOidcOptions, | ||||
| 		dockertestutil.DockerRestartPolicy); err == nil { | ||||
| 		s.mockOIDC.r = pmockoidc | ||||
| 	} else { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// headscale needs to set up the provider with a specific | ||||
| 	// IP addr to ensure we get the correct config from the well-known | ||||
| 	// endpoint. | ||||
| 	network := s.Networks()[0] | ||||
| 	ipAddr := s.mockOIDC.r.GetIPInNetwork(network) | ||||
|  | ||||
| 	log.Println("Waiting for headscale mock oidc to be ready for tests") | ||||
| 	hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) | ||||
|  | ||||
| 	if err := s.pool.Retry(func() error { | ||||
| 		oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) | ||||
| 		httpClient := &http.Client{} | ||||
| 		ctx := context.Background() | ||||
| 		req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) | ||||
| 		resp, err := httpClient.Do(req) | ||||
| 		if err != nil { | ||||
| 			log.Printf("headscale mock OIDC tests is not ready: %s\n", err) | ||||
|  | ||||
| 			return err | ||||
| 		} | ||||
| 		defer resp.Body.Close() | ||||
|  | ||||
| 		if resp.StatusCode != http.StatusOK { | ||||
| 			return errStatusCodeNotOK | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	}); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	s.mockOIDC.cfg = &types.OIDCConfig{ | ||||
| 		Issuer: fmt.Sprintf( | ||||
| 			"http://%s/oidc", | ||||
| 			hostEndpoint, | ||||
| 		), | ||||
| 		ClientID:                   "superclient", | ||||
| 		ClientSecret:               "supersecret", | ||||
| 		OnlyStartIfOIDCIsAvailable: true, | ||||
| 	} | ||||
|  | ||||
| 	log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type extraServiceFunc func(*Scenario, string) (*dockertest.Resource, error) | ||||
|  | ||||
| func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { | ||||
| 	// port, err := dockertestutil.RandomFreeHostPort() | ||||
| 	// if err != nil { | ||||
| 	// 	log.Fatalf("could not find an open port: %s", err) | ||||
| 	// } | ||||
| 	// portNotation := fmt.Sprintf("%d/tcp", port) | ||||
|  | ||||
| 	hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) | ||||
|  | ||||
| 	hostname := fmt.Sprintf("hs-webservice-%s", hash) | ||||
|  | ||||
| 	network, ok := s.networks[prefixedNetworkName(networkName)] | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("network does not exist: %s", networkName) | ||||
| 	} | ||||
|  | ||||
| 	webOpts := &dockertest.RunOptions{ | ||||
| 		Name: hostname, | ||||
| 		Cmd:  []string{"/bin/sh", "-c", "cd / ; python3 -m http.server --bind :: 80"}, | ||||
| 		// ExposedPorts: []string{portNotation}, | ||||
| 		// PortBindings: map[docker.Port][]docker.PortBinding{ | ||||
| 		// 	docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, | ||||
| 		// }, | ||||
| 		Networks: []*dockertest.Network{network}, | ||||
| 		Env:      []string{}, | ||||
| 	} | ||||
|  | ||||
| 	webBOpts := &dockertest.BuildOptions{ | ||||
| 		Dockerfile: hsic.IntegrationTestDockerFileName, | ||||
| 		ContextDir: dockerContextPath, | ||||
| 	} | ||||
|  | ||||
| 	web, err := s.pool.BuildAndRunWithBuildOptions( | ||||
| 		webBOpts, | ||||
| 		webOpts, | ||||
| 		dockertestutil.DockerRestartPolicy) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// headscale needs to set up the provider with a specific | ||||
| 	// IP addr to ensure we get the correct config from the well-known | ||||
| 	// endpoint. | ||||
| 	// ipAddr := web.GetIPInNetwork(network) | ||||
|  | ||||
| 	// log.Println("Waiting for headscale mock oidc to be ready for tests") | ||||
| 	// hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) | ||||
|  | ||||
| 	// if err := s.pool.Retry(func() error { | ||||
| 	// 	oidcConfigURL := fmt.Sprintf("http://%s/etc/hostname", hostEndpoint) | ||||
| 	// 	httpClient := &http.Client{} | ||||
| 	// 	ctx := context.Background() | ||||
| 	// 	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) | ||||
| 	// 	resp, err := httpClient.Do(req) | ||||
| 	// 	if err != nil { | ||||
| 	// 		log.Printf("headscale mock OIDC tests is not ready: %s\n", err) | ||||
|  | ||||
| 	// 		return err | ||||
| 	// 	} | ||||
| 	// 	defer resp.Body.Close() | ||||
|  | ||||
| 	// 	if resp.StatusCode != http.StatusOK { | ||||
| 	// 		return errStatusCodeNotOK | ||||
| 	// 	} | ||||
|  | ||||
| 	// 	return nil | ||||
| 	// }); err != nil { | ||||
| 	// 	return err | ||||
| 	// } | ||||
|  | ||||
| 	return web, nil | ||||
| } | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| ) | ||||
|  | ||||
| // This file is intended to "test the test framework", by proxy it will also test | ||||
| @@ -33,7 +34,7 @@ func TestHeadscale(t *testing.T) { | ||||
|  | ||||
| 	user := "test-space" | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	scenario, err := NewScenario(ScenarioSpec{}) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| @@ -68,38 +69,6 @@ func TestHeadscale(t *testing.T) { | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // If subtests are parallel, then they will start before setup is run. | ||||
| // This might mean we approach setup slightly wrong, but for now, ignore | ||||
| // the linter | ||||
| // nolint:tparallel | ||||
| func TestCreateTailscale(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	user := "only-create-containers" | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| 	scenario.users[user] = &User{ | ||||
| 		Clients: make(map[string]TailscaleClient), | ||||
| 	} | ||||
|  | ||||
| 	t.Run("create-tailscale", func(t *testing.T) { | ||||
| 		err := scenario.CreateTailscaleNodesInUser(user, "all", 3) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to add tailscale nodes: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		if clients := len(scenario.users[user].Clients); clients != 3 { | ||||
| 			t.Fatalf("wrong number of tailscale clients: %d != %d", clients, 3) | ||||
| 		} | ||||
|  | ||||
| 		// TODO(kradalby): Test "all" version logic | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // If subtests are parallel, then they will start before setup is run. | ||||
| // This might mean we approach setup slightly wrong, but for now, ignore | ||||
| // the linter | ||||
| @@ -114,7 +83,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { | ||||
|  | ||||
| 	count := 1 | ||||
|  | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 	scenario, err := NewScenario(ScenarioSpec{}) | ||||
| 	assertNoErr(t, err) | ||||
| 	defer scenario.ShutdownAssertNoPanics(t) | ||||
|  | ||||
| @@ -142,7 +111,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("create-tailscale", func(t *testing.T) { | ||||
| 		err := scenario.CreateTailscaleNodesInUser(user, "unstable", count) | ||||
| 		err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[TestDefaultNetwork])) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to add tailscale nodes: %s", err) | ||||
| 		} | ||||
|   | ||||
| @@ -50,15 +50,15 @@ var retry = func(times int, sleepInterval time.Duration, | ||||
|  | ||||
| func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario { | ||||
| 	t.Helper() | ||||
| 	scenario, err := NewScenario(dockertestMaxWait()) | ||||
|  | ||||
| 	spec := ScenarioSpec{ | ||||
| 		NodesPerUser: clientsPerUser, | ||||
| 		Users:        []string{"user1", "user2"}, | ||||
| 	} | ||||
| 	scenario, err := NewScenario(spec) | ||||
| 	assertNoErr(t, err) | ||||
|  | ||||
| 	spec := map[string]int{ | ||||
| 		"user1": clientsPerUser, | ||||
| 		"user2": clientsPerUser, | ||||
| 	} | ||||
|  | ||||
| 	err = scenario.CreateHeadscaleEnv(spec, | ||||
| 	err = scenario.CreateHeadscaleEnv( | ||||
| 		[]tsic.Option{ | ||||
| 			tsic.WithSSH(), | ||||
|  | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"net/netip" | ||||
| 	"net/url" | ||||
|  | ||||
| 	"github.com/juanfont/headscale/hscontrol/util" | ||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||
| 	"github.com/juanfont/headscale/integration/tsic" | ||||
| 	"tailscale.com/ipn/ipnstate" | ||||
| @@ -27,6 +28,9 @@ type TailscaleClient interface { | ||||
| 	Up() error | ||||
| 	Down() error | ||||
| 	IPs() ([]netip.Addr, error) | ||||
| 	MustIPs() []netip.Addr | ||||
| 	MustIPv4() netip.Addr | ||||
| 	MustIPv6() netip.Addr | ||||
| 	FQDN() (string, error) | ||||
| 	Status(...bool) (*ipnstate.Status, error) | ||||
| 	MustStatus() *ipnstate.Status | ||||
| @@ -38,6 +42,7 @@ type TailscaleClient interface { | ||||
| 	WaitForPeers(expected int) error | ||||
| 	Ping(hostnameOrIP string, opts ...tsic.PingOption) error | ||||
| 	Curl(url string, opts ...tsic.CurlOption) (string, error) | ||||
| 	Traceroute(netip.Addr) (util.Traceroute, error) | ||||
| 	ID() string | ||||
| 	ReadFile(path string) ([]byte, error) | ||||
|  | ||||
|   | ||||
| @@ -13,6 +13,7 @@ import ( | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"runtime/debug" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -81,6 +82,7 @@ type TailscaleInContainer struct { | ||||
| 	workdir           string | ||||
| 	netfilter         string | ||||
| 	extraLoginArgs    []string | ||||
| 	withAcceptRoutes  bool | ||||
|  | ||||
| 	// build options, solely for HEAD | ||||
| 	buildConfig TailscaleInContainerBuildConfig | ||||
| @@ -101,26 +103,10 @@ func WithCACert(cert []byte) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithOrCreateNetwork sets the Docker container network to use with | ||||
| // the Tailscale instance, if the parameter is nil, a new network, | ||||
| // isolating the TailscaleClient, will be created. If a network is | ||||
| // passed, the Tailscale instance will join the given network. | ||||
| func WithOrCreateNetwork(network *dockertest.Network) Option { | ||||
| // WithNetwork sets the Docker container network to use with | ||||
| // the Tailscale instance. | ||||
| func WithNetwork(network *dockertest.Network) Option { | ||||
| 	return func(tsic *TailscaleInContainer) { | ||||
| 		if network != nil { | ||||
| 			tsic.network = network | ||||
|  | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		network, err := dockertestutil.GetFirstOrCreateNetwork( | ||||
| 			tsic.pool, | ||||
| 			fmt.Sprintf("%s-network", tsic.hostname), | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			log.Fatalf("failed to create network: %s", err) | ||||
| 		} | ||||
|  | ||||
| 		tsic.network = network | ||||
| 	} | ||||
| } | ||||
| @@ -212,11 +198,17 @@ func WithExtraLoginArgs(args []string) Option { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithAcceptRoutes tells the node to accept incomming routes. | ||||
| func WithAcceptRoutes() Option { | ||||
| 	return func(tsic *TailscaleInContainer) { | ||||
| 		tsic.withAcceptRoutes = true | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // New returns a new TailscaleInContainer instance. | ||||
| func New( | ||||
| 	pool *dockertest.Pool, | ||||
| 	version string, | ||||
| 	network *dockertest.Network, | ||||
| 	opts ...Option, | ||||
| ) (*TailscaleInContainer, error) { | ||||
| 	hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength) | ||||
| @@ -231,7 +223,6 @@ func New( | ||||
| 		hostname: hostname, | ||||
|  | ||||
| 		pool: pool, | ||||
| 		network: network, | ||||
|  | ||||
| 		withEntrypoint: []string{ | ||||
| 			"/bin/sh", | ||||
| @@ -244,6 +235,10 @@ func New( | ||||
| 		opt(tsic) | ||||
| 	} | ||||
|  | ||||
| 	if tsic.network == nil { | ||||
| 		return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) | ||||
| 	} | ||||
|  | ||||
| 	tailscaleOptions := &dockertest.RunOptions{ | ||||
| 		Name:       hostname, | ||||
| 		Networks:   []*dockertest.Network{tsic.network}, | ||||
| @@ -442,7 +437,7 @@ func (t *TailscaleInContainer) Login( | ||||
| 		"--login-server=" + loginServer, | ||||
| 		"--authkey=" + authKey, | ||||
| 		"--hostname=" + t.hostname, | ||||
| 		"--accept-routes=false", | ||||
| 		fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes), | ||||
| 	} | ||||
|  | ||||
| 	if t.extraLoginArgs != nil { | ||||
| @@ -597,6 +592,33 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { | ||||
| 	return ips, nil | ||||
| } | ||||
|  | ||||
| func (t *TailscaleInContainer) MustIPs() []netip.Addr { | ||||
| 	ips, err := t.IPs() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|  | ||||
| 	return ips | ||||
| } | ||||
|  | ||||
| func (t *TailscaleInContainer) MustIPv4() netip.Addr { | ||||
| 	for _, ip := range t.MustIPs() { | ||||
| 		if ip.Is4() { | ||||
| 			return ip | ||||
| 		} | ||||
| 	} | ||||
| 	panic("no ipv4 found") | ||||
| } | ||||
|  | ||||
| func (t *TailscaleInContainer) MustIPv6() netip.Addr { | ||||
| 	for _, ip := range t.MustIPs() { | ||||
| 		if ip.Is6() { | ||||
| 			return ip | ||||
| 		} | ||||
| 	} | ||||
| 	panic("no ipv6 found") | ||||
| } | ||||
|  | ||||
| // Status returns the ipnstate.Status of the Tailscale instance. | ||||
| func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) { | ||||
| 	command := []string{ | ||||
| @@ -992,6 +1014,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err | ||||
| 		), | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		log.Printf("command: %v", command) | ||||
| 		log.Printf( | ||||
| 			"failed to run ping command from %s to %s, err: %s", | ||||
| 			t.Hostname(), | ||||
| @@ -1108,6 +1131,26 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error) { | ||||
| 	command := []string{ | ||||
| 		"traceroute", | ||||
| 		ip.String(), | ||||
| 	} | ||||
|  | ||||
| 	var result util.Traceroute | ||||
| 	stdout, stderr, err := t.Execute(command) | ||||
| 	if err != nil { | ||||
| 		return result, err | ||||
| 	} | ||||
|  | ||||
| 	result, err = util.ParseTraceroute(stdout + stderr) | ||||
| 	if err != nil { | ||||
| 		return result, err | ||||
| 	} | ||||
|  | ||||
| 	return result, nil | ||||
| } | ||||
|  | ||||
| // WriteFile save file inside the Tailscale container. | ||||
| func (t *TailscaleInContainer) WriteFile(path string, data []byte) error { | ||||
| 	return integrationutil.WriteFileToContainer(t.pool, t.container, path, data) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user