From 48cc98b787f1155cc95e1d79f83d361099152987 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 24 Feb 2026 18:51:11 +0000 Subject: [PATCH] hscontrol, cli: add auth register and approve commands Implement AuthRegister and AuthApprove gRPC handlers and add corresponding CLI commands (headscale auth register, approve, reject) for managing pending auth requests including SSH check approvals. Updates #1850 --- cmd/headscale/cli/auth.go | 93 ++++++++++++++++++++++++++++++++++++++ cmd/headscale/cli/nodes.go | 5 +- hscontrol/grpcv1.go | 55 ++++++++++++++++++++++ hscontrol/handlers.go | 4 +- 4 files changed, 153 insertions(+), 4 deletions(-) create mode 100644 cmd/headscale/cli/auth.go diff --git a/cmd/headscale/cli/auth.go b/cmd/headscale/cli/auth.go new file mode 100644 index 00000000..8a5476dd --- /dev/null +++ b/cmd/headscale/cli/auth.go @@ -0,0 +1,93 @@ +package cli + +import ( + "context" + "fmt" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(authCmd) + + authRegisterCmd.Flags().StringP("user", "u", "", "User") + authRegisterCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authRegisterCmd, "user", "auth-id") + authCmd.AddCommand(authRegisterCmd) + + authApproveCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authApproveCmd, "auth-id") + authCmd.AddCommand(authApproveCmd) + + authRejectCmd.Flags().String("auth-id", "", "Auth ID") + mustMarkRequired(authRejectCmd, "auth-id") + authCmd.AddCommand(authRejectCmd) +} + +var authCmd = &cobra.Command{ + Use: "auth", + Short: "Manage node authentication and approval", +} + +var authRegisterCmd = &cobra.Command{ + Use: "register", + Short: "Register a node to your network", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + user, _ := cmd.Flags().GetString("user") + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthRegisterRequest{ + AuthId: authID, + User: user, + } + + response, err := client.AuthRegister(ctx, request) + if err != nil { + return fmt.Errorf("registering node: %w", err) + } + + return printOutput( + cmd, + response.GetNode(), + fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName())) + }), +} + +var authApproveCmd = &cobra.Command{ + Use: "approve", + Short: "Approve a pending authentication request", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthApproveRequest{ + AuthId: authID, + } + + response, err := client.AuthApprove(ctx, request) + if err != nil { + return fmt.Errorf("approving auth request: %w", err) + } + + return printOutput(cmd, response, "Auth request approved") + }), +} + +var authRejectCmd = &cobra.Command{ + Use: "reject", + Short: "Reject a pending authentication request", + RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { + authID, _ := cmd.Flags().GetString("auth-id") + + request := &v1.AuthRejectRequest{ + AuthId: authID, + } + + response, err := client.AuthReject(ctx, request) + if err != nil { + return fmt.Errorf("rejecting auth request: %w", err) + } + + return printOutput(cmd, response, "Auth request rejected") + }), +} diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 930efc29..0ed7330c 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -64,8 +64,9 @@ var nodeCmd = &cobra.Command{ } var registerNodeCmd = &cobra.Command{ - Use: "register", - Short: "Registers a node to your network", + Use: "register", + Short: "Registers a node to your network", + Deprecated: "use 'headscale auth register --auth-id --user ' instead", RunE: grpcRunE(func(ctx context.Context, client v1.HeadscaleServiceClient, cmd *cobra.Command, args []string) error { user, _ := cmd.Flags().GetString("user") registrationID, _ := cmd.Flags().GetString("key") diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index c0fd5a3e..567efc8e 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -856,4 +856,59 @@ func (api headscaleV1APIServer) Health( return response, healthErr } +func (api headscaleV1APIServer) AuthRegister( + ctx context.Context, + request *v1.AuthRegisterRequest, +) (*v1.AuthRegisterResponse, error) { + resp, err := api.RegisterNode(ctx, &v1.RegisterNodeRequest{ + Key: request.GetAuthId(), + User: request.GetUser(), + }) + if err != nil { + return nil, err + } + + return &v1.AuthRegisterResponse{Node: resp.GetNode()}, nil +} + +func (api headscaleV1APIServer) AuthApprove( + ctx context.Context, + request *v1.AuthApproveRequest, +) (*v1.AuthApproveResponse, error) { + authID, err := types.AuthIDFromString(request.GetAuthId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err) + } + + authReq, ok := api.h.state.GetAuthCacheEntry(authID) + if !ok { + return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID) + } + + authReq.FinishAuth(types.AuthVerdict{}) + + return &v1.AuthApproveResponse{}, nil +} + +func (api headscaleV1APIServer) AuthReject( + ctx context.Context, + request *v1.AuthRejectRequest, +) (*v1.AuthRejectResponse, error) { + authID, err := types.AuthIDFromString(request.GetAuthId()) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid auth_id: %v", err) + } + + authReq, ok := api.h.state.GetAuthCacheEntry(authID) + if !ok { + return nil, status.Errorf(codes.NotFound, "no pending auth session for auth_id %s", authID) + } + + authReq.FinishAuth(types.AuthVerdict{ + Err: fmt.Errorf("auth request rejected"), + }) + + return &v1.AuthRejectResponse{}, nil +} + func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 9f544f8d..57469ce0 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -282,7 +282,7 @@ func (a *AuthProviderWeb) AuthHandler( } func authIDFromRequest(req *http.Request) (types.AuthID, error) { - registrationId, err := urlParam[types.AuthID](req, "auth_id") + raw, err := urlParam[string](req, "auth_id") if err != nil { return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) } @@ -290,7 +290,7 @@ func authIDFromRequest(req *http.Request) (types.AuthID, error) { // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - err = registrationId.Validate() + registrationId, err := types.AuthIDFromString(raw) if err != nil { return "", NewHTTPError(http.StatusBadRequest, "invalid registration id", fmt.Errorf("parsing auth_id from URL: %w", err)) }