Files
ForgeBucket/internal/domain/sshserver/session.go
T

200 lines
6.2 KiB
Go

package sshserver
import (
"encoding/binary"
"fmt"
"io"
"log"
"os/exec"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
"github.com/forgeo/forgebucket/internal/models"
)
// handleSession services a single SSH "session" channel opened by username.
//
// Every request before the first "exec" is ignored; if it wants a reply it is
// refused. The first "exec" request's payload is decoded as a command string
// and handed to runGitCommand. The resulting exit code is sent back as an
// "exit-status" request, and the channel is closed on return. At most one
// command runs per session. If the exec payload is malformed, the request is
// refused and the session ends without running anything.
func (s *Server) handleSession(ch ssh.Channel, reqs <-chan *ssh.Request, username string) {
	defer ch.Close()
	for req := range reqs {
		if req.Type != "exec" {
			if req.WantReply {
				req.Reply(false, nil) //nolint:errcheck
			}
			continue
		}

		cmdStr, err := parseExecPayload(req.Payload)
		if err != nil {
			req.Reply(false, nil) //nolint:errcheck
			return
		}
		req.Reply(true, nil) //nolint:errcheck

		// x/crypto/ssh requires the channel's request stream to be serviced
		// for as long as the channel is open. Without this, requests that
		// arrive while git runs would pile up unread and can stall the
		// connection. DiscardRequests refuses them, and the goroutine exits
		// once reqs is closed (when the channel or connection shuts down).
		go ssh.DiscardRequests(reqs)

		exitCode := s.runGitCommand(ch, username, cmdStr)
		sendExitStatus(ch, uint32(exitCode))
		return
	}
}
// runGitCommand runs one git command for username over the SSH channel ch.
//
// Steps:
//   - parse and whitelist cmdStr,
//   - map the path argument to an owner/repo pair and resolve it,
//   - enforce access: write for receive-pack; read for upload-pack on private
//     repos only,
//   - run git against the repository's on-disk path.
//
// git's stdin and stdout are wired to ch and its stderr to ch.Stderr().
// User-facing failures are written to ch.Stderr() as "error: ..." lines.
//
// The returned value is the exit code to report to the client. It is git's
// own exit code when git ran and exited normally, and 1 for every other
// failure, including a git process killed by a signal.
func (s *Server) runGitCommand(ch ssh.Channel, username, cmdStr string) int {
	gitCmd, repoArg, err := parseGitCommand(cmdStr)
	if err != nil {
		fmt.Fprintf(ch.Stderr(), "error: %v\n", err)
		return 1
	}

	// Resolve owner/repo from the path argument (e.g. "/alice/myrepo.git" or
	// "alice/myrepo.git"). Both segments must be present and non-empty.
	path := strings.TrimPrefix(strings.TrimSuffix(repoArg, ".git"), "/")
	parts := strings.SplitN(path, "/", 2)
	if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
		fmt.Fprintf(ch.Stderr(), "error: invalid repository path\n")
		return 1
	}
	ownerName, repoName := parts[0], parts[1]

	repo, err := s.resolveRepo(ownerName, repoName)
	if err != nil {
		fmt.Fprintf(ch.Stderr(), "error: repository not found\n")
		return 1
	}

	// Check permissions.
	if gitCmd == "receive-pack" {
		if !s.hasPermission(repo, username, "write") {
			fmt.Fprintf(ch.Stderr(), "error: you do not have write access to this repository\n")
			return 1
		}
	} else if repo.IsPrivate && !s.hasPermission(repo, username, "read") {
		// upload-pack: public repos are readable by everyone; private repos
		// require at least read permission.
		fmt.Fprintf(ch.Stderr(), "error: you do not have read access to this repository\n")
		return 1
	}

	// Exec the git subcommand against the bare repo path on disk.
	// The disk path comes from the DB — never from user input. The same
	// cleaned path is used for the argument and the working directory.
	diskPath := filepath.Clean(repo.DiskPath)
	cmd := exec.Command("git", gitCmd, diskPath)
	cmd.Dir = diskPath
	cmd.Env = []string{"GIT_TERMINAL_PROMPT=0", "HOME=/tmp"}
	cmd.Stdin = ch
	cmd.Stdout = ch
	cmd.Stderr = ch.Stderr()
	if err := cmd.Run(); err != nil {
		log.Printf("sshserver: git %s for %s/%s: %v", gitCmd, ownerName, repoName, err)
		// ExitCode() is -1 when the process was killed by a signal. The caller
		// converts the code to uint32, so -1 would become 0xFFFFFFFF. Start
		// failures, I/O errors and signal deaths are all reported as 1.
		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() >= 0 {
			return exitErr.ExitCode()
		}
		return 1
	}
	return 0
}
// resolveRepo finds the repository named repoName under ownerName.
//
// ownerName may be a user's username or a workspace's handle. The user
// namespace is checked first. If no user with that name owns a matching
// repository, the workspace namespace is tried next.
//
// It returns a "not found" error when neither lookup matches. If any database
// query fails, it returns that failure wrapped with context instead of
// reporting it as "not found".
func (s *Server) resolveRepo(ownerName, repoName string) (*models.Repository, error) {
	var u models.User
	userFound, err := s.db.Where("username = ?", ownerName).Get(&u)
	if err != nil {
		return nil, fmt.Errorf("looking up user %q: %w", ownerName, err)
	}
	if userFound {
		var repo models.Repository
		repoFound, err := s.db.Where("owner_id = ? AND name = ?", u.ID, repoName).Get(&repo)
		if err != nil {
			return nil, fmt.Errorf("looking up repo %q of user %q: %w", repoName, ownerName, err)
		}
		if repoFound {
			return &repo, nil
		}
	}

	var ws models.Workspace
	wsFound, err := s.db.Where("handle = ?", ownerName).Get(&ws)
	if err != nil {
		return nil, fmt.Errorf("looking up workspace %q: %w", ownerName, err)
	}
	if wsFound {
		var repo models.Repository
		repoFound, err := s.db.Where("workspace_id = ? AND name = ?", ws.ID, repoName).Get(&repo)
		if err != nil {
			return nil, fmt.Errorf("looking up repo %q of workspace %q: %w", repoName, ownerName, err)
		}
		if repoFound {
			return &repo, nil
		}
	}

	return nil, fmt.Errorf("not found")
}
// hasPermission reports whether username holds at least the required
// permission on repo. Levels are ordered read < write < admin.
//
// The repository owner (the user whose ID equals repo.OwnerID) is always
// granted. Anyone else needs a RepoMember row for the repo whose Permission
// ranks at or above required.
//
// Access is denied in all of these cases:
//   - required is not a recognized level,
//   - the user does not exist,
//   - the user has no membership row,
//   - a database query fails (the failure is also logged).
func (s *Server) hasPermission(repo *models.Repository, username, required string) bool {
	rank := func(level string) int {
		switch level {
		case "read":
			return 1
		case "write":
			return 2
		case "admin":
			return 3
		}
		return 0
	}

	need := rank(required)
	if need == 0 {
		// An unrecognized level would otherwise rank 0 and be satisfied by
		// any membership row (fail-open); treat it as a caller bug and deny.
		return false
	}

	var u models.User
	found, err := s.db.Where("username = ?", username).Get(&u)
	if err != nil {
		log.Printf("sshserver: permission check: looking up user %q: %v", username, err)
		return false
	}
	if !found {
		return false
	}
	if u.ID == repo.OwnerID {
		return true
	}

	var m models.RepoMember
	found, err = s.db.Where("repo_id = ? AND user_id = ?", repo.ID, u.ID).Get(&m)
	if err != nil {
		log.Printf("sshserver: permission check: looking up membership of %q: %v", username, err)
		return false
	}
	if !found {
		return false
	}
	return rank(m.Permission) >= need
}
// parseGitCommand splits the SSH exec command string into the git subcommand
// and the repository path argument. Only upload-pack and receive-pack are
// permitted.
//
// Both the dashed form ("git-upload-pack '/path'") and the spaced form
// ("git upload-pack /path") are accepted.
//
// The command word must match exactly: "git-upload-pack-foo" and similar
// near-misses are rejected. One pair of matching single or double quotes
// around the path is removed.
//
// It returns gitCmd as "upload-pack" or "receive-pack". An error is returned
// for any other command or when the path argument is missing or empty.
func parseGitCommand(cmdStr string) (gitCmd string, repoPath string, err error) {
	// splitWord splits s at its first space or tab into the leading word and
	// the remainder with surrounding whitespace trimmed.
	splitWord := func(s string) (string, string) {
		i := strings.IndexAny(s, " \t")
		if i < 0 {
			return s, ""
		}
		return s[:i], strings.TrimSpace(s[i:])
	}

	verb, arg := splitWord(strings.TrimSpace(cmdStr))
	if verb == "git" {
		// "git upload-pack X" is equivalent to "git-upload-pack X".
		var sub string
		sub, arg = splitWord(arg)
		verb = "git-" + sub
	}

	switch verb {
	case "git-upload-pack":
		gitCmd = "upload-pack"
	case "git-receive-pack":
		gitCmd = "receive-pack"
	default:
		return "", "", fmt.Errorf("unsupported command: only git-upload-pack and git-receive-pack are allowed")
	}

	// git clients send the path single-quoted; strip exactly one matching pair.
	if n := len(arg); n >= 2 && (arg[0] == '\'' || arg[0] == '"') && arg[n-1] == arg[0] {
		arg = arg[1 : n-1]
	}
	if arg == "" {
		return "", "", fmt.Errorf("missing repository path argument")
	}
	return gitCmd, arg, nil
}
// parseExecPayload decodes an SSH "exec" request payload: a 4-byte big-endian
// length followed by that many bytes of command string. Bytes after the
// declared length are ignored.
//
// It returns an error if the payload is shorter than the 4-byte header or if
// the declared length exceeds the bytes actually present.
func parseExecPayload(payload []byte) (string, error) {
	if len(payload) < 4 {
		return "", fmt.Errorf("exec payload too short")
	}
	length := uint64(binary.BigEndian.Uint32(payload[:4]))
	body := payload[4:]
	// Compare in uint64. On 32-bit platforms, int(length) goes negative for
	// lengths >= 2^31, which would pass a signed check and then panic when
	// slicing. The payload comes from the remote client.
	if length > uint64(len(body)) {
		return "", fmt.Errorf("exec payload length mismatch")
	}
	return string(body[:length]), nil
}
// sendExitStatus tells the client how the remote command finished by sending
// an "exit-status" channel request (RFC 4254 §6.10) carrying code as a single
// uint32. No reply is requested, and a failure to send is ignored (best effort).
func sendExitStatus(ch ssh.Channel, code uint32) {
	payload := ssh.Marshal(struct{ Status uint32 }{Status: code})
	_, _ = ch.SendRequest("exit-status", false, payload)
}
// Compile-time assertion that ssh.Channel provides Stderr(), which
// runGitCommand uses to send error messages and git's stderr to the client.
// Stderr is part of the ssh.Channel interface itself, so call sites can use it
// directly without a type assertion. A package-level blank var performs the
// check at compile time with no init-time side effects.
var _ interface{ Stderr() io.ReadWriter } = ssh.Channel(nil)