diff --git a/cmd/headscale/cli/auth.go b/cmd/headscale/cli/auth.go new file mode 100644 index 00000000..8a5476dd --- /dev/null +++ b/cmd/headscale/cli/auth.go @@ -0,0 +1,93 @@ +package cli + +import ( + "context" + "fmt" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(authCmd) + + authRegisterCmd.Flags().StringP("user", "u", "", "User") + authRegisterCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authRegisterCmd, "user", "auth-id") + authCmd.AddCommand(authRegisterCmd) + + authApproveCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authApproveCmd, "auth-id") + authCmd.AddCommand(authApproveCmd) + + authRejectCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authRejectCmd, "auth-id") + authCmd.AddCommand(authRejectCmd) +} + +var authCmd = &cobra.Command{ + Use: "auth", + Short: "Manage node authentication and approval", +} + +var authRegisterCmd = &cobra.Command{ + Use: "register", + Short: "Register a node to your network", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + user, _ := cmd.Flags().GetString("user") + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthRegisterRequest{ + AuthId: authID, + User: user, + } + + response, err := client.AuthRegister(ctx, request) + if err != nil { + return fmt.Errorf("registering node: %w", err) + } + + return printOutput( + cmd, + response.GetNode(), + fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName())) + }), +} + +var authApproveCmd = &cobra.Command{ + Use: "approve", + Short: "Approve a pending authentication request", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthApproveRequest{ + AuthId: authID, + } + + response, err := client.AuthApprove(ctx, request) + if err != nil { + return fmt.Errorf("approving auth request: %w", err) + } + + return printOutput(cmd, response, "Auth request approved") + }), +} + +var authRejectCmd = &cobra.Command{ + Use: "reject", + Short: "Reject a pending authentication request", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthRejectRequest{ + AuthId: authID, + } + + response, err := client.AuthReject(ctx, request) + if err != nil { + return fmt.Errorf("rejecting auth request: %w", err) + } + + return printOutput(cmd, response, "Auth request rejected") + }), +} diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 930efc29..0ed7330c 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -64,8 +64,9 @@ var nodeCmd = &cobra.Command{ } var registerNodeCmd = &cobra.Command{ - Use: "register", - Short: "Registers a node to your network", + Use: "register", + Short: "Registers a node to your network", + Deprecated: "use 'headscale auth register --auth-id --user ' instead", RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { user, _ := cmd.Flags().GetString("user") registrationID, _ := cmd.Flags().GetString("key") diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index c0fd5a3e..567efc8e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -856,4 +856,59 @@ func (api headscaleV1APIServer) Health( return response, healthErr } +func (api headscaleV1APIServer) AuthRegister( + ctx context.Context, + request *v1.AuthRegisterRequest, +) (*v1.AuthRegisterResponse, error) { + resp, err := api.RegisterNode(ctx, &v1.RegisterNodeRequest{ + Key: request.GetAuthId(), + User: request.GetUser(), + }) + if err != nil { + return nil, err + } + + return &v1.AuthRegisterResponse{Node: resp.GetNode()}, nil +} + +func (api headscaleV1APIServer) AuthApprove( + ctx context.Context, + request *v1.AuthApproveRequest, +) (*v1.AuthApproveResponse, error) { + authID, err := types.AuthIDFromString(request.GetAuthId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err) + } + + authReq, ok := api.h.state.GetAuthCacheEntry(authID) + if !ok { + return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID) + } + + authReq.FinishAuth(types.AuthVerdict{}) + + return &v1.AuthApproveResponse{}, nil +} + +func (api headscaleV1APIServer) AuthReject( + ctx context.Context, + request *v1.AuthRejectRequest, +) (*v1.AuthRejectResponse, error) { + authID, err := types.AuthIDFromString(request.GetAuthId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err) + } + + authReq, ok := api.h.state.GetAuthCacheEntry(authID) + if !ok { + return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID) + } + + authReq.FinishAuth(types.AuthVerdict{ + Err: fmt.Errorf("auth request rejected"), + }) + + return &v1.AuthRejectResponse{}, nil +} + func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 9f544f8d..57469ce0 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -282,7 +282,7 @@ func (a *AuthProviderWeb) AuthHandler( } func authIDFromRequest(req *http.Request) (types.AuthID, error) { - registrationId, err := urlParam[types.AuthID](req, "auth_id") + raw, err := urlParam[string](req, "auth_id") if err != nil { return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) } @@ -290,7 +290,7 @@ func authIDFromRequest(req *http.Request) (types.AuthID, error) { // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - err = registrationId.Validate() + registrationId, err := types.AuthIDFromString(raw) if err != nil { return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) }