Writes and verify image
This commit is contained in:
42
clitools/pkg/controller/osimage/apply.go
Normal file
42
clitools/pkg/controller/osimage/apply.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package osimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ApplyImageStreamed(ctx context.Context, opts ApplyOptions) (*ApplyResult, error) {
|
||||
if err := ValidateApplyOptions(opts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := CheckTargetSafe(opts.TargetPath, opts.ExpectedRawSize); err != nil {
|
||||
return nil, fmt.Errorf("unsafe target %q: %w", opts.TargetPath, err)
|
||||
}
|
||||
|
||||
src, closeFn, err := OpenDecompressedHTTPStream(ctx, opts.URL, opts.HTTPTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open source stream: %w", err)
|
||||
}
|
||||
defer closeFn()
|
||||
|
||||
written, err := WriteStreamToTarget(ctx, src, opts.TargetPath, opts.ExpectedRawSize, opts.BufferSize, opts.Progress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write target: %w", err)
|
||||
}
|
||||
|
||||
sum, err := VerifyTargetSHA256(ctx, opts.TargetPath, opts.ExpectedRawSize, opts.BufferSize, opts.Progress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verify target: %w", err)
|
||||
}
|
||||
|
||||
if err := VerifySHA256(sum, opts.ExpectedRawSHA256); err != nil {
|
||||
return nil, fmt.Errorf("final disk checksum mismatch: %w", err)
|
||||
}
|
||||
|
||||
return &ApplyResult{
|
||||
BytesWritten: written,
|
||||
VerifiedSHA256: sum,
|
||||
VerificationOK: true,
|
||||
}, nil
|
||||
}
|
||||
15
clitools/pkg/controller/osimage/helpers.go
Normal file
15
clitools/pkg/controller/osimage/helpers.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package osimage
|
||||
|
||||
func PercentOf(done, total int64) int64 {
|
||||
if total <= 0 {
|
||||
return 0
|
||||
}
|
||||
p := (done * 100) / total
|
||||
if p < 0 {
|
||||
return 0
|
||||
}
|
||||
if p > 100 {
|
||||
return 100
|
||||
}
|
||||
return p
|
||||
}
|
||||
121
clitools/pkg/controller/osimage/progress.go
Normal file
121
clitools/pkg/controller/osimage/progress.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package osimage
|
||||
|
||||
import (
|
||||
"k8s.io/klog/v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
type progressState struct {
|
||||
lastTime time.Time
|
||||
lastPercent int64
|
||||
lastBucket int64
|
||||
}
|
||||
|
||||
type ProgressLogger struct {
|
||||
minInterval time.Duration
|
||||
bucketSize int64
|
||||
states map[string]*progressState
|
||||
}
|
||||
|
||||
func NewProgressLogger(minSeconds int, bucketSize int64) *ProgressLogger {
|
||||
if minSeconds < 0 {
|
||||
minSeconds = 0
|
||||
}
|
||||
if bucketSize <= 0 {
|
||||
bucketSize = 10
|
||||
}
|
||||
|
||||
return &ProgressLogger{
|
||||
minInterval: time.Duration(minSeconds) * time.Second,
|
||||
bucketSize: bucketSize,
|
||||
states: make(map[string]*progressState),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ProgressLogger) state(stage string) *progressState {
|
||||
s, ok := l.states[stage]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
s = &progressState{
|
||||
lastPercent: -1,
|
||||
lastBucket: -1,
|
||||
}
|
||||
l.states[stage] = s
|
||||
return s
|
||||
}
|
||||
|
||||
func (l *ProgressLogger) Log(p Progress) {
|
||||
if p.BytesTotal <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
percent := PercentOf(p.BytesComplete, p.BytesTotal)
|
||||
|
||||
now := time.Now()
|
||||
bucket := percent / l.bucketSize
|
||||
s := l.state(p.Stage)
|
||||
|
||||
// Always log first visible progress
|
||||
if s.lastPercent == -1 {
|
||||
s.lastPercent = percent
|
||||
s.lastBucket = bucket
|
||||
s.lastTime = now
|
||||
klog.V(4).InfoS(p.Stage, "progress", percent)
|
||||
return
|
||||
}
|
||||
|
||||
// Always log completion once
|
||||
if percent == 100 && s.lastPercent < 100 {
|
||||
s.lastPercent = 100
|
||||
s.lastBucket = 100 / l.bucketSize
|
||||
s.lastTime = now
|
||||
klog.V(4).InfoS(p.Stage, "progress", 100)
|
||||
return
|
||||
}
|
||||
|
||||
// Log if we crossed a new milestone bucket
|
||||
if bucket > s.lastBucket {
|
||||
s.lastPercent = percent
|
||||
s.lastBucket = bucket
|
||||
s.lastTime = now
|
||||
klog.V(4).InfoS(p.Stage, "progress", percent)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise allow a timed refresh if progress moved
|
||||
if now.Sub(s.lastTime) >= l.minInterval && percent > s.lastPercent {
|
||||
s.lastPercent = percent
|
||||
s.lastTime = now
|
||||
klog.V(4).InfoS(p.Stage, "progress", percent)
|
||||
}
|
||||
}
|
||||
|
||||
type TimeBasedUpdater struct {
|
||||
interval time.Duration
|
||||
lastRun time.Time
|
||||
}
|
||||
|
||||
func NewTimeBasedUpdater(seconds int) *TimeBasedUpdater {
|
||||
if seconds <= 0 {
|
||||
seconds = 15
|
||||
}
|
||||
return &TimeBasedUpdater{
|
||||
interval: time.Duration(seconds) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *TimeBasedUpdater) Run(fn func() error) error {
|
||||
now := time.Now()
|
||||
|
||||
if !u.lastRun.IsZero() && now.Sub(u.lastRun) < u.interval {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := fn(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.lastRun = now
|
||||
return nil
|
||||
}
|
||||
50
clitools/pkg/controller/osimage/stream.go
Normal file
50
clitools/pkg/controller/osimage/stream.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package osimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
func OpenDecompressedHTTPStream(ctx context.Context, url string, timeout time.Duration) (io.Reader, func() error, error) {
|
||||
if url == "" {
|
||||
return nil, nil, fmt.Errorf("url is required")
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Minute
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("http get %q: %w", url, err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return nil, nil, fmt.Errorf("unexpected status: %s", resp.Status)
|
||||
}
|
||||
|
||||
dec, err := zstd.NewReader(resp.Body)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return nil, nil, fmt.Errorf("create zstd decoder: %w", err)
|
||||
}
|
||||
|
||||
closeFn := func() error {
|
||||
dec.Close()
|
||||
return resp.Body.Close()
|
||||
}
|
||||
|
||||
return dec, closeFn, nil
|
||||
}
|
||||
29
clitools/pkg/controller/osimage/types.go
Normal file
29
clitools/pkg/controller/osimage/types.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package osimage
|
||||
|
||||
import "time"
|
||||
|
||||
type ApplyOptions struct {
|
||||
URL string
|
||||
TargetPath string
|
||||
ExpectedRawSHA256 string
|
||||
ExpectedRawSize int64
|
||||
|
||||
HTTPTimeout time.Duration
|
||||
BufferSize int
|
||||
|
||||
Progress ProgressFunc
|
||||
}
|
||||
|
||||
type Progress struct {
|
||||
Stage string
|
||||
BytesComplete int64
|
||||
BytesTotal int64
|
||||
}
|
||||
|
||||
type ProgressFunc func(Progress)
|
||||
|
||||
type ApplyResult struct {
|
||||
BytesWritten int64
|
||||
VerifiedSHA256 string
|
||||
VerificationOK bool
|
||||
}
|
||||
104
clitools/pkg/controller/osimage/verify.go
Normal file
104
clitools/pkg/controller/osimage/verify.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package osimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func VerifyTargetSHA256(ctx context.Context, targetPath string, expectedSize int64,
|
||||
bufferSize int, progress ProgressFunc) (string, error) {
|
||||
if targetPath == "" {
|
||||
return "", fmt.Errorf("target path is required")
|
||||
}
|
||||
if expectedSize <= 0 {
|
||||
return "", fmt.Errorf("expected raw size is required for verification")
|
||||
}
|
||||
if bufferSize <= 0 {
|
||||
bufferSize = 4 * 1024 * 1024
|
||||
}
|
||||
|
||||
f, err := os.Open(targetPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("open target for verify: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
buf := make([]byte, bufferSize)
|
||||
|
||||
var readTotal int64
|
||||
limited := io.LimitReader(f, expectedSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := limited.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := h.Write(buf[:n]); werr != nil {
|
||||
return "", fmt.Errorf("hash target: %w", werr)
|
||||
}
|
||||
readTotal += int64(n)
|
||||
if progress != nil {
|
||||
progress(Progress{
|
||||
Stage: "verify",
|
||||
BytesComplete: readTotal,
|
||||
BytesTotal: expectedSize,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read target: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if readTotal != expectedSize {
|
||||
return "", fmt.Errorf("verify size mismatch: got %d want %d", readTotal, expectedSize)
|
||||
}
|
||||
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func ValidateApplyOptions(opts ApplyOptions) error {
|
||||
if opts.URL == "" {
|
||||
return fmt.Errorf("url is required")
|
||||
}
|
||||
if opts.TargetPath == "" {
|
||||
return fmt.Errorf("target path is required")
|
||||
}
|
||||
if opts.ExpectedRawSHA256 == "" {
|
||||
return fmt.Errorf("expected raw sha256 is required")
|
||||
}
|
||||
if opts.ExpectedRawSize <= 0 {
|
||||
return fmt.Errorf("expected raw size must be > 0")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func VerifySHA256(got, expected string) error {
|
||||
expected = NormalizeSHA256(expected)
|
||||
if expected == "" {
|
||||
return nil
|
||||
}
|
||||
got = NormalizeSHA256(got)
|
||||
if got != expected {
|
||||
return fmt.Errorf("sha256 mismatch: got %s want %s", got, expected)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeSHA256(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
35
clitools/pkg/controller/osimage/verify_safe.go
Normal file
35
clitools/pkg/controller/osimage/verify_safe.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build !dev
|
||||
|
||||
package osimage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func CheckTargetSafe(targetPath string, expectedRawSize int64) error {
|
||||
|
||||
if !strings.HasPrefix(targetPath, "/dev/") {
|
||||
return fmt.Errorf("target must be a device path under /dev")
|
||||
}
|
||||
|
||||
st, err := os.Stat(targetPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat target: %w", err)
|
||||
}
|
||||
|
||||
mode := st.Mode()
|
||||
if mode&os.ModeDevice == 0 {
|
||||
return fmt.Errorf("target is not a device")
|
||||
}
|
||||
|
||||
// TODO: Add stronger checks
|
||||
// - EnsureNotMounted(targetPath)
|
||||
// - EnsureNotCurrentRoot(targetPath)
|
||||
// - EnsurePartitionNotWholeDisk(targetPath)
|
||||
// - EnsureCapacity(targetPath, expectedRawSize)
|
||||
|
||||
_ = expectedRawSize
|
||||
return nil
|
||||
}
|
||||
7
clitools/pkg/controller/osimage/verify_unsafe.go
Normal file
7
clitools/pkg/controller/osimage/verify_unsafe.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build dev
|
||||
|
||||
package osimage
|
||||
|
||||
func CheckTargetSafe(targetPath string, expectedRawSize int64) error {
|
||||
return nil
|
||||
}
|
||||
82
clitools/pkg/controller/osimage/write.go
Normal file
82
clitools/pkg/controller/osimage/write.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package osimage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
func WriteStreamToTarget(ctx context.Context,
|
||||
src io.Reader,
|
||||
targetPath string,
|
||||
expectedSize int64, bufferSize int,
|
||||
progress ProgressFunc,
|
||||
) (int64, error) {
|
||||
if targetPath == "" {
|
||||
return 0, fmt.Errorf("target path is required")
|
||||
}
|
||||
if bufferSize <= 0 {
|
||||
bufferSize = 4 * 1024 * 1024
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(targetPath, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("open target: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
written, err := copyWithProgressBuffer(ctx, f, src, expectedSize, "flash", progress, make([]byte, bufferSize))
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
|
||||
if expectedSize > 0 && written != expectedSize {
|
||||
return written, fmt.Errorf("written size mismatch: got %d want %d", written, expectedSize)
|
||||
}
|
||||
|
||||
if err := f.Sync(); err != nil {
|
||||
return written, fmt.Errorf("sync target: %w", err)
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func copyWithProgressBuffer(ctx context.Context, dst io.Writer, src io.Reader, total int64, stage string, progress ProgressFunc, buf []byte) (int64, error) {
|
||||
var written int64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return written, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
nr, er := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, ew := dst.Write(buf[:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
if progress != nil {
|
||||
progress(Progress{
|
||||
Stage: stage,
|
||||
BytesComplete: written,
|
||||
BytesTotal: total,
|
||||
})
|
||||
}
|
||||
}
|
||||
if ew != nil {
|
||||
return written, ew
|
||||
}
|
||||
if nw != nr {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if er != nil {
|
||||
if er == io.EOF {
|
||||
return written, nil
|
||||
}
|
||||
return written, fmt.Errorf("copy %s: %w", stage, er)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user