#!/usr/bin/env python3
"""Simple utility for programming SPI flash via a UART shell in the bootloader.

The bootloader possesses a 1-page buffer which we can write to, read from and
checksum. It also possesses commands for transfers between this buffer and the
SPI flash, flash erasure, and launching the 2nd stage boot code.
"""

import argparse
import os
import sys
import time

try:
    import serial
    import serial.tools.list_ports
except ImportError:
    # pyserial is only needed once a real port is opened (see main()); the
    # protocol helpers below work with any object offering read()/write().
    serial = None

PAGESIZE = 2 ** 8
SECTORSIZE = 2 ** 12
BLOCKSIZE = 2 ** 16

SRAM_LOAD_ADDR = 0x2 << 28
SRAM_EXEC_ADDR = SRAM_LOAD_ADDR + 0xc0


class ProtocolError(Exception):
    """Raised when the bootloader's response is not what the protocol expects."""


class CMD:
    """Single-byte command codes understood by the bootloader shell."""

    NOP = b'\n'
    WRITE_BUF = b'w'
    READ_BUF = b'r'
    GET_CHECKSUM = b'c'
    SET_ADDR = b'a'
    WRITE_FLASH = b'W'
    READ_FLASH = b'R'
    ERASE_SECTOR = b'E'
    ERASE_BLOCK = b'B'
    BOOT_2ND = b'2'
    LOAD_MEM = b'l'
    EXEC_MEM = b'x'
    ACK = b':'


def _u24_be(value: int) -> bytes:
    """Encode the low 24 bits of *value* as 3 big-endian bytes."""
    return (value & 0xffffff).to_bytes(3, "big")


def reset_shell(port):
    """Block until the bootloader shell answers a NOP with an ACK.

    Repeats forever: sends PAGESIZE + 1 NOPs, discards both serial buffers,
    sends one more NOP and checks whether the pending input ends with ACK.

    *port* is an open pyserial ``Serial`` object.
    """
    while True:
        # NOTE(review): PAGESIZE + 1 bytes is presumably enough to complete a
        # half-received page write and get back to the prompt -- confirm
        # against the bootloader source.
        port.write(CMD.NOP * (PAGESIZE + 1))
        port.reset_output_buffer()
        time.sleep(0.02)
        port.reset_input_buffer()
        port.write(CMD.NOP)
        time.sleep(0.02)
        if port.readable() and port.read_all().endswith(CMD.ACK):
            break


def check_ack(port):
    """Read one byte from *port* and raise ProtocolError unless it is ACK.

    A read that returns nothing (e.g. a timeout) also raises.
    """
    resp = port.read()
    if resp != CMD.ACK:
        raise ProtocolError(
            "Expected ACK {!r}, got {!r}".format(CMD.ACK, resp))


def write_buf(port, data, check=True):
    """Send exactly one page of *data* to the bootloader's page buffer.

    With ``check=False`` the ACK is left unread so that the caller can queue
    another command first and collect the ACK afterwards.

    Raises:
        ValueError: if *data* is not exactly PAGESIZE bytes long.
        ProtocolError: if *check* is true and no ACK is received.
    """
    if len(data) != PAGESIZE:
        raise ValueError(
            "Buffer writes must be exactly {} bytes, got {}".format(
                PAGESIZE, len(data)))
    port.write(CMD.WRITE_BUF)
    port.write(data)
    if check:
        check_ack(port)

# TODO checksum


def read_buf(port, pipelined=False):
    """Read one page back from the bootloader's page buffer.

    With ``pipelined=True`` the READ_BUF command is assumed to have been sent
    by the caller already, so only the response is collected.

    Raises:
        ProtocolError: if fewer than PAGESIZE bytes arrive.
    """
    if not pipelined:
        port.write(CMD.READ_BUF)
    resp = port.read(PAGESIZE)
    if len(resp) != PAGESIZE:
        raise ProtocolError(
            "Short buffer read: expected {} bytes, got {}".format(
                PAGESIZE, len(resp)))
    return resp

# TODO checksum


def set_addr(port, addr):
    """Set the bootloader's address register.

    Only the low 24 bits of *addr* are transmitted (big-endian). The
    bootloader is expected to echo the three bytes and then send an ACK.

    Raises:
        ProtocolError: if the echo does not match or no ACK follows.
    """
    addr_b = _u24_be(addr)
    port.write(CMD.SET_ADDR + addr_b)
    echo = port.read(3)
    if echo != addr_b:
        raise ProtocolError(
            "Address echo mismatch: sent {!r}, got {!r}".format(addr_b, echo))
    check_ack(port)


def progress(header, frac, width=40):
    """Redraw a one-line progress bar on stdout.

    *frac* is the completed fraction (0..1); *width* is the bar width in
    character cells.
    """
    n = int(width * frac)
    sys.stdout.write("\r" + header + "▕" + "▒" * n + " " * (width - n) + "▏")
    # No newline is written, so flush explicitly or a line-buffered stdout
    # will not show the bar until the operation has finished.
    sys.stdout.flush()


def read_flash(port, addr, size):
    """Read *size* bytes of flash starting at *addr* and return them.

    Data is fetched a page at a time; any excess from the final page is
    trimmed from the result.
    """
    set_addr(port, addr)
    pages = []
    for a in range(addr, addr + size, PAGESIZE):
        port.write(CMD.READ_FLASH)
        # Should have space for 2 cmds in FIFO. Hoist this one up from
        # read_buf()
        port.write(CMD.READ_BUF)
        progress("Read:  ", (a - addr) / size)
        check_ack(port)
        pages.append(read_buf(port, pipelined=True))
    progress("Read:  ", 1)
    return b"".join(pages)[:size]


def write_flash(port, addr, data):
    """Program *data* into flash starting at the page-aligned *addr*.

    *data* is zero-padded up to a whole number of pages. The caller's object
    is never modified. The flash is not erased here.

    Raises:
        ValueError: if *addr* is not a multiple of PAGESIZE.
        ProtocolError: if the bootloader fails to acknowledge a page.
    """
    if addr % PAGESIZE != 0:
        raise ValueError(
            "Write address 0x{:x} is not aligned to {} bytes".format(
                addr, PAGESIZE))
    tail = len(data) % PAGESIZE
    if tail:
        data = data + bytes(PAGESIZE - tail)
    npages = len(data) // PAGESIZE
    set_addr(port, addr)
    for i in range(npages):
        write_buf(port, data[i * PAGESIZE:(i + 1) * PAGESIZE], check=False)
        port.write(CMD.WRITE_FLASH)
        progress("Write: ", i / npages)
        # Two commands were queued back to back, so two ACKs are due: one for
        # the buffer write and one for the flash write.
        check_ack(port)
        check_ack(port)
    progress("Write: ", 1)


def erase_flash(port, addr, size):
    """Erase every sector touched by the range [addr, addr + size).

    The range is widened to sector boundaries. A block erase is issued
    whenever the current position is block-aligned and at least one whole
    block remains; otherwise a sector erase is issued.
    """
    end = addr + size
    if end % SECTORSIZE != 0:
        end += SECTORSIZE - end % SECTORSIZE
    start = addr - addr % SECTORSIZE
    set_addr(port, start)
    a = start
    while a < end:
        progress("Erase: ", (a - start) / (end - start))
        if end - a >= BLOCKSIZE and a % BLOCKSIZE == 0:
            port.write(CMD.ERASE_BLOCK)
            a += BLOCKSIZE
        else:
            port.write(CMD.ERASE_SECTOR)
            a += SECTORSIZE
        check_ack(port)
    progress("Erase: ", 1)


def run_2nd_stage(port):
    """Ask the bootloader to launch the second-stage boot code."""
    # NOTE(review): the meaning of 0x123456 is not visible in this file.
    set_addr(port, 0x123456)
    port.write(CMD.BOOT_2ND)
    check_ack(port)


def load_mem(port, data):
    """Send *data* to the bootloader with the LOAD_MEM command.

    The length is transmitted as its low 24 bits, big-endian. One ACK is
    expected after the command and another after the payload.
    """
    # set_addr() transmits only the low 24 bits of this value.
    set_addr(port, 0xb00000 | SRAM_LOAD_ADDR)
    port.write(CMD.LOAD_MEM + _u24_be(len(data)))
    check_ack(port)
    port.write(data)  # yup
    check_ack(port)


def exec_mem(port):
    """Ask the bootloader to execute the previously loaded image."""
    # set_addr() transmits only the low 24 bits of this value.
    set_addr(port, 0xb00000 | SRAM_EXEC_ADDR)
    port.write(CMD.EXEC_MEM)
    check_ack(port)


def any_int(x):
    """Parse an integer literal in any base Python accepts (0x.., 0o.., 0b..)."""
    return int(x, 0)


def _parse_args(argv=None):
    """Build the command-line parser and parse *argv* (default: sys.argv)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("file", nargs="?",
                        help="Filename to read from or dump to.")
    parser.add_argument("--uart", "-u",
                        help="Path to UART device (e.g. /dev/ttyUSB0)")
    parser.add_argument("--baud", "-b", type=any_int, default=1000000,
                        help="Baud rate for UART. Default 1 Mbaud")
    parser.add_argument("--write", "-w", action="store_true",
                        help="Write to flash (default is read)")
    parser.add_argument("--verify", "-v", action="store_true",
                        help="Verify contents after programming (use with --write)")
    parser.add_argument("--run", "-r", action="store_true",
                        help="Run code from flash after programming")
    parser.add_argument("--execute", "-x", action="store_true",
                        help="Load code directly into SRAM and execute it. Not to be combined with -w,-v,-r")
    parser.add_argument("--start", "-s", type=any_int, default=0,
                        help="Base address to start read/write")
    parser.add_argument("--len", "-l", type=any_int,
                        help="Number of bytes to read, or override file length for write.")
    return parser.parse_args(argv)


def _check_args(args):
    """Validate option combinations and fill in the default length.

    Does some simple parameter checking to make it harder to accidentally
    trash your device. Exits the process with a message on any problem.
    """
    if args.write and args.file is None:
        sys.exit("Must specify filename for write")
    if args.verify and not args.write:
        sys.exit("Verify is only valid for a write operation")
    if args.run and not args.write:
        sys.exit("Run is only valid for a write operation")
    if args.write and (args.start % PAGESIZE != 0):
        sys.exit("Writes must be aligned on a {}-byte boundary.".format(PAGESIZE))
    if args.execute and (args.write or args.verify or args.run or args.start or args.len):
        sys.exit("--execute is not compatible with flash-related arguments")
    if args.execute and args.file is None:
        sys.exit("Must specify filename for execute")

    if args.write or args.execute:
        try:
            filesize = os.stat(args.file).st_size
        except OSError:
            sys.exit("Could not open file '{}'".format(args.file))
        if args.len is None:
            args.len = filesize
    elif args.len is None:
        args.len = PAGESIZE


def _confirm_overwrite(path):
    """Prompt until the user answers y or n; return True for y."""
    resp = ""
    while resp not in ("y", "n"):
        resp = input("File '{}' exists. Overwrite? (y/n) ".format(path))
        resp = resp.lower().strip()
    return resp == "y"


def _default_uart():
    """Return the last serial device in sorted order, or exit if none exist."""
    devices = sorted(p.device for p in serial.tools.list_ports.comports())
    if not devices:
        sys.exit("No serial ports found; specify one with --uart")
    return devices[-1]


def _hexdump(data, base):
    """Format *data* as hex, 8 bytes per line, each line prefixed by its address."""
    out = []
    for i, byte in enumerate(data):
        if i % 8 == 0:
            out.append("{:06x}: ".format(base + i))
        out.append("{:02x}".format(byte))
        out.append("\n" if i % 8 == 7 else " ")
    out.append("\n")
    return "".join(out)


def _do_execute(port, args):
    """Load the file into SRAM and run it."""
    print("Loading...")
    with open(args.file, "rb") as f:
        image = f.read()
    load_mem(port, image)
    print("Load ok, running...")
    exec_mem(port)


def _do_write(port, args):
    """Erase, program and optionally verify/run the image from the file."""
    with open(args.file, "rb") as f:
        data = f.read()
    if len(data) > args.len:
        data = data[:args.len]
    else:
        data = data + bytes(args.len - len(data))  # zero padding

    start = time.monotonic()
    erase_flash(port, args.start, args.len)
    print(" Took {:.1f} s\n".format(time.monotonic() - start))

    start = time.monotonic()
    write_flash(port, args.start, data)
    print(" Took {:.1f} s\n".format(time.monotonic() - start))

    if args.verify:
        start = time.monotonic()
        readback = read_flash(port, args.start, args.len)
        print(" Took {:.1f} s".format(time.monotonic() - start))
        if readback != data:
            sys.exit("Verification failed.")  # TODO: better info
    if args.run:
        print("Launching flash second stage")
        run_2nd_stage(port)
    print("Done")


def _do_read(port, args):
    """Read flash and either save it to the file or hexdump it to stdout."""
    start = time.monotonic()
    data = read_flash(port, args.start, args.len)
    print(" Took {:.1f} s\nDone".format(time.monotonic() - start))
    if args.file:
        with open(args.file, "wb") as f:
            f.write(data)
    else:
        sys.stdout.write(_hexdump(data, args.start))


def main(argv=None):
    """Command-line entry point."""
    args = _parse_args(argv)
    _check_args(args)

    # Don't want to overwrite their image if they forget -w
    if not (args.write or args.execute) and args.file and os.path.exists(args.file):
        if not _confirm_overwrite(args.file):
            sys.exit(0)

    if serial is None:
        sys.exit("This tool requires the pyserial package (pip install pyserial)")
    if args.uart is None:
        args.uart = _default_uart()

    # Summarise what we're about to do
    if args.execute:
        print("Loading {} bytes to SRAM and running".format(args.len))
    else:
        print("{} {} bytes {} {}, starting at address 0x{:06x}".format(
            "Writing" + (" and verifying" if args.verify else "") if args.write else "Reading",
            args.len,
            "from" if args.write else "to",
            "stdout" if args.file is None else "file " + args.file,
            args.start,
        ))

    # And then do it
    # need to allow for page erase, upward of 60 ms. (Uh, actually it seems
    # to be much longer)
    with serial.Serial(args.uart, args.baud, timeout=1) as port:
        print("Waiting for bootloader on {}...".format(port.name))
        reset_shell(port)
        print("")
        if args.execute:
            _do_execute(port, args)
        elif args.write:
            _do_write(port, args)
        else:
            _do_read(port, args)


if __name__ == "__main__":
    main()